aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Andrew Selle <aselle@andyselle.com>2017-11-10 17:31:52 -0800
committerGravatar Andrew Selle <aselle@andyselle.com>2017-11-10 17:31:52 -0800
commit4d165c75c3880c20d698f5ba95c8d16ecc93cea7 (patch)
tree6a47c4218c503841af0f975c49e3d36e7f510bc7
parent054b515feec0a3fca4cfb1f29adbf423c9027c3a (diff)
parenteabee43e39186d91956277a4fc8b0dd566a68e3b (diff)
Merge remote-tracking branch 'staging/master' into pushsync
-rw-r--r--configure.py22
-rw-r--r--tensorflow/BUILD119
-rw-r--r--tensorflow/c/c_api.cc4
-rw-r--r--tensorflow/c/c_api_test.cc4
-rw-r--r--tensorflow/c/eager/BUILD2
-rw-r--r--tensorflow/c/eager/tape.cc102
-rw-r--r--tensorflow/c/eager/tape.h501
-rw-r--r--tensorflow/compiler/jit/kernels/xla_launch_op.cc1
-rw-r--r--tensorflow/compiler/jit/xla_compilation_cache.cc3
-rw-r--r--tensorflow/compiler/tests/BUILD15
-rw-r--r--tensorflow/compiler/tests/cholesky_op_test.py126
-rw-r--r--tensorflow/compiler/tf2xla/BUILD1
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD4
-rw-r--r--tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc114
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cholesky_op.cc39
-rw-r--r--tensorflow/compiler/tf2xla/lib/BUILD120
-rw-r--r--tensorflow/compiler/tf2xla/lib/batch_dot.cc154
-rw-r--r--tensorflow/compiler/tf2xla/lib/batch_dot.h51
-rw-r--r--tensorflow/compiler/tf2xla/lib/cholesky.cc166
-rw-r--r--tensorflow/compiler/tf2xla/lib/cholesky.h38
-rw-r--r--tensorflow/compiler/tf2xla/lib/triangular_solve.cc175
-rw-r--r--tensorflow/compiler/tf2xla/lib/triangular_solve.h46
-rw-r--r--tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc69
-rw-r--r--tensorflow/compiler/tf2xla/lib/util.cc107
-rw-r--r--tensorflow/compiler/tf2xla/lib/util.h54
-rw-r--r--tensorflow/compiler/tf2xla/type_util.cc3
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.h6
-rw-r--r--tensorflow/compiler/tf2xla/xla_helpers.cc19
-rw-r--r--tensorflow/compiler/xla/BUILD2
-rw-r--r--tensorflow/compiler/xla/array.h159
-rw-r--r--tensorflow/compiler/xla/array_test.cc45
-rw-r--r--tensorflow/compiler/xla/client/client.cc3
-rw-r--r--tensorflow/compiler/xla/client/computation_builder.h1
-rw-r--r--tensorflow/compiler/xla/client/lib/BUILD1
-rw-r--r--tensorflow/compiler/xla/client/lib/testing.cc57
-rw-r--r--tensorflow/compiler/xla/client/lib/testing.h4
-rw-r--r--tensorflow/compiler/xla/client/local_client.cc20
-rw-r--r--tensorflow/compiler/xla/client/local_client.h16
-rw-r--r--tensorflow/compiler/xla/literal_util.cc121
-rw-r--r--tensorflow/compiler/xla/literal_util.h25
-rw-r--r--tensorflow/compiler/xla/literal_util_test.cc62
-rw-r--r--tensorflow/compiler/xla/primitive_util.cc8
-rw-r--r--tensorflow/compiler/xla/primitive_util.h7
-rw-r--r--tensorflow/compiler/xla/service/BUILD2
-rw-r--r--tensorflow/compiler/xla/service/backend.cc4
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc21
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD4
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_options.cc16
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_options.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc569
-rw-r--r--tensorflow/compiler/xla/service/cpu/dot_op_emitter.h37
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc17
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emission_utils.h11
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc10
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/layout_assignment.cc4
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor.h6
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h10
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_thunk.cc22
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_compiler.cc34
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.cc8
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse_test.cc24
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc65
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc48
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.h15
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc26
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc118
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h31
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.h6
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_config.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_config.h10
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_runner.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.cc74
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.h87
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_test.cc73
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc49
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc2
-rw-r--r--tensorflow/compiler/xla/service/interpreter/executable.cc2
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment_test.cc32
-rw-r--r--tensorflow/compiler/xla/service/liveness_util.cc6
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/BUILD24
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc65
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h128
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc68
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h29
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc8
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_util.h2
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc150
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h174
-rw-r--r--tensorflow/compiler/xla/service/local_service.cc20
-rw-r--r--tensorflow/compiler/xla/service/logical_buffer_analysis.cc38
-rw-r--r--tensorflow/compiler/xla/service/logical_buffer_analysis.h2
-rw-r--r--tensorflow/compiler/xla/service/service.h2
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc14
-rw-r--r--tensorflow/compiler/xla/service/shaped_buffer.cc8
-rw-r--r--tensorflow/compiler/xla/service/shaped_buffer.h4
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis.cc90
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis.h2
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc45
-rw-r--r--tensorflow/compiler/xla/service/user_computation.cc6
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier.cc4
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier_test.cc6
-rw-r--r--tensorflow/compiler/xla/shape_tree.h157
-rw-r--r--tensorflow/compiler/xla/shape_tree_test.cc36
-rw-r--r--tensorflow/compiler/xla/shape_util.cc1
-rw-r--r--tensorflow/compiler/xla/shape_util.h3
-rw-r--r--tensorflow/compiler/xla/tests/BUILD24
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.cc54
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h23
-rw-r--r--tensorflow/compiler/xla/tests/client_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/compilation_cache_test.cc8
-rw-r--r--tensorflow/compiler/xla/tests/compute_constant_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/convolution_test.cc160
-rw-r--r--tensorflow/compiler/xla/tests/dot_operation_test.cc97
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.cc13
-rw-r--r--tensorflow/compiler/xla/tests/llvm_compiler_test.cc143
-rw-r--r--tensorflow/compiler/xla/tests/local_client_execute_test.cc10
-rw-r--r--tensorflow/compiler/xla/tests/local_client_test_base.cc3
-rw-r--r--tensorflow/compiler/xla/tests/map_test.cc8
-rw-r--r--tensorflow/compiler/xla/tests/reshape_test.cc5
-rw-r--r--tensorflow/compiler/xla/tests/test_macros.h6
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc120
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.h64
-rw-r--r--tensorflow/compiler/xla/tools/BUILD1
-rw-r--r--tensorflow/compiler/xla/tools/parser/README.md16
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_lexer.cc90
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_lexer.h8
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_parser.cc1141
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc448
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_token.h4
-rw-r--r--tensorflow/compiler/xla/tools/replay_computation.cc1
-rw-r--r--tensorflow/compiler/xla/types.h3
-rw-r--r--tensorflow/compiler/xla/window_util.cc28
-rw-r--r--tensorflow/compiler/xla/xla_data.proto26
-rw-r--r--tensorflow/contrib/BUILD1
-rw-r--r--tensorflow/contrib/__init__.py1
-rw-r--r--tensorflow/contrib/android/asset_manager_filesystem.cc4
-rw-r--r--tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h10
-rw-r--r--tensorflow/contrib/boosted_trees/ops/prediction_ops.cc2
-rw-r--r--tensorflow/contrib/cmake/tf_c.cmake2
-rw-r--r--tensorflow/contrib/cmake/tf_core_framework.cmake3
-rw-r--r--tensorflow/contrib/cmake/tf_core_kernels.cmake1
-rw-r--r--tensorflow/contrib/cmake/tf_core_ops.cmake1
-rwxr-xr-xtensorflow/contrib/cmake/tf_python.cmake15
-rw-r--r--tensorflow/contrib/data/BUILD13
-rw-r--r--tensorflow/contrib/data/__init__.py4
-rw-r--r--tensorflow/contrib/data/ops/dataset_ops.cc232
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py225
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py77
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py78
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py27
-rw-r--r--tensorflow/contrib/data/python/ops/BUILD40
-rw-r--r--tensorflow/contrib/data/python/ops/batching.py89
-rw-r--r--tensorflow/contrib/data/python/ops/dataset_ops.py8
-rw-r--r--tensorflow/contrib/data/python/ops/error_ops.py2
-rw-r--r--tensorflow/contrib/data/python/ops/grouping.py2
-rw-r--r--tensorflow/contrib/data/python/ops/interleave_ops.py2
-rw-r--r--tensorflow/contrib/data/python/ops/iterator_ops.py2
-rw-r--r--tensorflow/contrib/data/python/ops/readers.py2
-rw-r--r--tensorflow/contrib/data/python/ops/scan_ops.py2
-rw-r--r--tensorflow/contrib/eager/README.md2
-rw-r--r--tensorflow/contrib/eager/python/network.py63
-rw-r--r--tensorflow/contrib/eager/python/network_test.py108
-rw-r--r--tensorflow/contrib/estimator/BUILD5
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head.py143
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head_test.py206
-rw-r--r--tensorflow/contrib/estimator/python/estimator/multi_head.py67
-rw-r--r--tensorflow/contrib/estimator/python/estimator/multi_head_test.py188
-rw-r--r--tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py43
-rw-r--r--tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py96
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py2
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/BUILD1
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py17
-rw-r--r--tensorflow/contrib/kfac/python/ops/layer_collection.py8
-rw-r--r--tensorflow/contrib/kfac/python/ops/loss_functions.py62
-rw-r--r--tensorflow/contrib/learn/BUILD2
-rw-r--r--tensorflow/contrib/lite/BUILD280
-rw-r--r--tensorflow/contrib/lite/allocation.cc122
-rw-r--r--tensorflow/contrib/lite/allocation.h94
-rw-r--r--tensorflow/contrib/lite/build_def.bzl233
-rw-r--r--tensorflow/contrib/lite/builtin_op_data.h164
-rw-r--r--tensorflow/contrib/lite/context.c92
-rw-r--r--tensorflow/contrib/lite/context.h298
-rw-r--r--tensorflow/contrib/lite/context_test.cc74
-rw-r--r--tensorflow/contrib/lite/error_reporter.cc50
-rw-r--r--tensorflow/contrib/lite/error_reporter.h54
-rw-r--r--tensorflow/contrib/lite/interpreter.cc567
-rw-r--r--tensorflow/contrib/lite/interpreter.h376
-rw-r--r--tensorflow/contrib/lite/interpreter_test.cc526
-rw-r--r--tensorflow/contrib/lite/java/BUILD111
-rw-r--r--tensorflow/contrib/lite/java/demo/.gitignore9
-rw-r--r--tensorflow/contrib/lite/java/demo/app/build.gradle58
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml42
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/BUILD47
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD26
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt1001
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java72
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java708
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java35
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java184
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.pngbin0 -> 490 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.pngbin0 -> 3136 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.pngbin0 -> 116 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.pngbin0 -> 320 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.pngbin0 -> 1915 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.pngbin0 -> 611 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.pngbin0 -> 4294 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.pngbin0 -> 952 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.pngbin0 -> 7279 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml50
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/layout/activity_camera.xml22
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml45
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml24
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml25
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/values-v11/template-styles.xml22
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/values-v21/base-colors.xml21
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/values-v21/base-template-styles.xml24
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml30
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/values/colors.xml19
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml24
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml18
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/values/template-dimens.xml32
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/values/template-styles.xml42
-rw-r--r--tensorflow/contrib/lite/java/demo/build.gradle23
-rw-r--r--tensorflow/contrib/lite/java/demo/gradle.properties17
-rw-r--r--tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.jarbin0 -> 53636 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.properties6
-rwxr-xr-xtensorflow/contrib/lite/java/demo/gradlew160
-rw-r--r--tensorflow/contrib/lite/java/demo/gradlew.bat90
-rw-r--r--tensorflow/contrib/lite/java/demo/settings.gradle1
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java76
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java172
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java276
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java71
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java44
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/package-info.java17
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/BUILD108
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/builtin_ops_jni.cc29
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/exception_jni.cc66
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/exception_jni.h50
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc446
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h151
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc242
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/tensor_jni.h74
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.cc26
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h36
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/version_script.lds11
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java34
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java221
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java406
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java32
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java105
-rw-r--r--tensorflow/contrib/lite/java/src/testdata/add.binbin0 -> 476 bytes
-rw-r--r--tensorflow/contrib/lite/java/src/testdata/float32.binbin0 -> 388 bytes
-rw-r--r--tensorflow/contrib/lite/java/src/testdata/int32.binbin0 -> 396 bytes
-rw-r--r--tensorflow/contrib/lite/java/src/testdata/int64.binbin0 -> 396 bytes
-rw-r--r--tensorflow/contrib/lite/java/src/testdata/invalid_model.bin1
-rw-r--r--tensorflow/contrib/lite/java/src/testdata/uint8.binbin0 -> 396 bytes
-rw-r--r--tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD30
-rw-r--r--tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java35
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD408
-rw-r--r--tensorflow/contrib/lite/kernels/activation_functor.h58
-rw-r--r--tensorflow/contrib/lite/kernels/activations.cc389
-rw-r--r--tensorflow/contrib/lite/kernels/activations_test.cc323
-rw-r--r--tensorflow/contrib/lite/kernels/add.cc184
-rw-r--r--tensorflow/contrib/lite/kernels/add_test.cc171
-rw-r--r--tensorflow/contrib/lite/kernels/basic_rnn.cc161
-rw-r--r--tensorflow/contrib/lite/kernels/basic_rnn_test.cc267
-rw-r--r--tensorflow/contrib/lite/kernels/concatenation.cc200
-rw-r--r--tensorflow/contrib/lite/kernels/concatenation_test.cc162
-rw-r--r--tensorflow/contrib/lite/kernels/conv.cc425
-rw-r--r--tensorflow/contrib/lite/kernels/conv_test.cc440
-rw-r--r--tensorflow/contrib/lite/kernels/depthwise_conv.cc289
-rw-r--r--tensorflow/contrib/lite/kernels/depthwise_conv_test.cc186
-rw-r--r--tensorflow/contrib/lite/kernels/embedding_lookup.cc104
-rw-r--r--tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc248
-rw-r--r--tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc166
-rw-r--r--tensorflow/contrib/lite/kernels/embedding_lookup_test.cc94
-rw-r--r--tensorflow/contrib/lite/kernels/fully_connected.cc307
-rw-r--r--tensorflow/contrib/lite/kernels/fully_connected_test.cc377
-rw-r--r--tensorflow/contrib/lite/kernels/gemm_support.cc68
-rw-r--r--tensorflow/contrib/lite/kernels/gemm_support.h54
-rw-r--r--tensorflow/contrib/lite/kernels/hashtable_lookup.cc155
-rw-r--r--tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc176
-rw-r--r--tensorflow/contrib/lite/kernels/internal/BUILD359
-rw-r--r--tensorflow/contrib/lite/kernels/internal/common.h107
-rw-r--r--tensorflow/contrib/lite/kernels/internal/compatibility.h78
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h65
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h987
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h1916
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h231
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h143
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h167
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h195
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc337
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h113
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h3715
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h138
-rw-r--r--tensorflow/contrib/lite/kernels/internal/quantization_util.cc95
-rw-r--r--tensorflow/contrib/lite/kernels/internal/quantization_util.h55
-rw-r--r--tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc108
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h115
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h138
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc165
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h189
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h2455
-rw-r--r--tensorflow/contrib/lite/kernels/internal/round.h39
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor.h87
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_test.cc55
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_utils.cc27
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_utils.h116
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc192
-rw-r--r--tensorflow/contrib/lite/kernels/internal/types.h81
-rw-r--r--tensorflow/contrib/lite/kernels/kernel_util.cc87
-rw-r--r--tensorflow/contrib/lite/kernels/kernel_util.h65
-rw-r--r--tensorflow/contrib/lite/kernels/l2norm.cc112
-rw-r--r--tensorflow/contrib/lite/kernels/l2norm_test.cc63
-rw-r--r--tensorflow/contrib/lite/kernels/local_response_norm.cc109
-rw-r--r--tensorflow/contrib/lite/kernels/local_response_norm_test.cc101
-rw-r--r--tensorflow/contrib/lite/kernels/lsh_projection.cc204
-rw-r--r--tensorflow/contrib/lite/kernels/lsh_projection_test.cc123
-rw-r--r--tensorflow/contrib/lite/kernels/lstm.cc515
-rw-r--r--tensorflow/contrib/lite/kernels/lstm_test.cc1088
-rw-r--r--tensorflow/contrib/lite/kernels/mul.cc167
-rw-r--r--tensorflow/contrib/lite/kernels/mul_test.cc127
-rw-r--r--tensorflow/contrib/lite/kernels/op_macros.h32
-rw-r--r--tensorflow/contrib/lite/kernels/optional_tensor_test.cc343
-rw-r--r--tensorflow/contrib/lite/kernels/padding.h28
-rw-r--r--tensorflow/contrib/lite/kernels/pooling.cc355
-rw-r--r--tensorflow/contrib/lite/kernels/pooling_test.cc161
-rw-r--r--tensorflow/contrib/lite/kernels/register.cc109
-rw-r--r--tensorflow/contrib/lite/kernels/register.h50
-rw-r--r--tensorflow/contrib/lite/kernels/reshape.cc91
-rw-r--r--tensorflow/contrib/lite/kernels/reshape_test.cc90
-rw-r--r--tensorflow/contrib/lite/kernels/resize_bilinear.cc129
-rw-r--r--tensorflow/contrib/lite/kernels/resize_bilinear_test.cc117
-rw-r--r--tensorflow/contrib/lite/kernels/skip_gram.cc160
-rw-r--r--tensorflow/contrib/lite/kernels/skip_gram_test.cc257
-rw-r--r--tensorflow/contrib/lite/kernels/softmax_test.cc143
-rw-r--r--tensorflow/contrib/lite/kernels/space_to_depth.cc146
-rw-r--r--tensorflow/contrib/lite/kernels/space_to_depth_test.cc102
-rw-r--r--tensorflow/contrib/lite/kernels/svdf.cc224
-rw-r--r--tensorflow/contrib/lite/kernels/svdf_test.cc312
-rw-r--r--tensorflow/contrib/lite/kernels/test_util.cc183
-rw-r--r--tensorflow/contrib/lite/kernels/test_util.h202
-rw-r--r--tensorflow/contrib/lite/model.cc673
-rw-r--r--tensorflow/contrib/lite/model.h165
-rw-r--r--tensorflow/contrib/lite/model_test.cc258
-rw-r--r--tensorflow/contrib/lite/models/smartreply/BUILD15
-rw-r--r--tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc119
-rw-r--r--tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc100
-rw-r--r--tensorflow/contrib/lite/models/smartreply/ops/normalize.cc105
-rw-r--r--tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc90
-rw-r--r--tensorflow/contrib/lite/models/smartreply/ops/predict.cc174
-rw-r--r--tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc183
-rw-r--r--tensorflow/contrib/lite/models/smartreply/predictor.cc116
-rw-r--r--tensorflow/contrib/lite/models/smartreply/predictor.h80
-rw-r--r--tensorflow/contrib/lite/models/smartreply/predictor_test.cc150
-rw-r--r--tensorflow/contrib/lite/models/speech_hotword_model_test.cc115
-rw-r--r--tensorflow/contrib/lite/models/speech_speakerid_model_test.cc114
-rw-r--r--tensorflow/contrib/lite/models/speech_terse_am_model_test.cc127
-rw-r--r--tensorflow/contrib/lite/models/speech_tts_model_test.cc116
-rw-r--r--tensorflow/contrib/lite/models/test_utils.h84
-rw-r--r--tensorflow/contrib/lite/nnapi/BUILD25
-rw-r--r--tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h1916
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc386
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.h66
-rw-r--r--tensorflow/contrib/lite/optional_debug_tools.cc108
-rw-r--r--tensorflow/contrib/lite/optional_debug_tools.h32
-rw-r--r--tensorflow/contrib/lite/python/BUILD46
-rw-r--r--tensorflow/contrib/lite/python/lite.py199
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py45
-rw-r--r--tensorflow/contrib/lite/schema/BUILD82
-rw-r--r--tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc91
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs346
-rw-r--r--tensorflow/contrib/lite/schema/schema_v0.fbs247
-rw-r--r--tensorflow/contrib/lite/schema/schema_v1.fbs295
-rw-r--r--tensorflow/contrib/lite/schema/schema_v2.fbs303
-rw-r--r--tensorflow/contrib/lite/schema/schema_v3.fbs326
-rw-r--r--tensorflow/contrib/lite/schema/upgrade_schema.py341
-rw-r--r--tensorflow/contrib/lite/schema/upgrade_schema_test.py317
-rw-r--r--tensorflow/contrib/lite/simple_memory_arena.cc136
-rw-r--r--tensorflow/contrib/lite/simple_memory_arena.h84
-rw-r--r--tensorflow/contrib/lite/simple_memory_arena_test.cc91
-rw-r--r--tensorflow/contrib/lite/string.h30
-rw-r--r--tensorflow/contrib/lite/string_util.cc117
-rw-r--r--tensorflow/contrib/lite/string_util.h91
-rw-r--r--tensorflow/contrib/lite/string_util_test.cc117
-rw-r--r--tensorflow/contrib/lite/testdata/0_subgraphs.binbin0 -> 80 bytes
-rw-r--r--tensorflow/contrib/lite/testdata/2_subgraphs.binbin0 -> 172 bytes
-rw-r--r--tensorflow/contrib/lite/testdata/empty_model.binbin0 -> 132 bytes
-rw-r--r--tensorflow/contrib/lite/testdata/multi_add.binbin0 -> 652 bytes
-rw-r--r--tensorflow/contrib/lite/testdata/multi_add.json46
-rw-r--r--tensorflow/contrib/lite/testdata/no_subgraphs.binbin0 -> 80 bytes
-rw-r--r--tensorflow/contrib/lite/testdata/test_model.binbin0 -> 496 bytes
-rw-r--r--tensorflow/contrib/lite/testdata/test_model_broken.binbin0 -> 432 bytes
-rw-r--r--tensorflow/contrib/lite/testdata/test_model_broken.json62
-rw-r--r--tensorflow/contrib/lite/testdata/two_subgraphs.binbin0 -> 172 bytes
-rw-r--r--tensorflow/contrib/lite/testing/BUILD213
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py1189
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples_report.py125
-rw-r--r--tensorflow/contrib/lite/testing/generated_examples_zip_test.cc279
-rw-r--r--tensorflow/contrib/lite/testing/message.cc96
-rw-r--r--tensorflow/contrib/lite/testing/message.h82
-rw-r--r--tensorflow/contrib/lite/testing/message_test.cc121
-rw-r--r--tensorflow/contrib/lite/testing/nnapi_example.cc114
-rw-r--r--tensorflow/contrib/lite/testing/parse_testdata.cc335
-rw-r--r--tensorflow/contrib/lite/testing/parse_testdata.h74
-rw-r--r--tensorflow/contrib/lite/testing/split.cc42
-rw-r--r--tensorflow/contrib/lite/testing/split.h77
-rw-r--r--tensorflow/contrib/lite/testing/split_test.cc57
-rw-r--r--tensorflow/contrib/lite/testing/test_runner.h124
-rw-r--r--tensorflow/contrib/lite/testing/test_runner_test.cc84
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.cc208
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.h62
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver_test.cc61
-rw-r--r--tensorflow/contrib/lite/testing/tokenize.cc95
-rw-r--r--tensorflow/contrib/lite/testing/tokenize.h42
-rw-r--r--tensorflow/contrib/lite/testing/tokenize_test.cc105
-rw-r--r--tensorflow/contrib/lite/toco/BUILD350
-rw-r--r--tensorflow/contrib/lite/toco/allocate_transient_arrays.cc318
-rw-r--r--tensorflow/contrib/lite/toco/allocate_transient_arrays.h44
-rw-r--r--tensorflow/contrib/lite/toco/args.h225
-rw-r--r--tensorflow/contrib/lite/toco/dump_graphviz.cc293
-rw-r--r--tensorflow/contrib/lite/toco/dump_graphviz.h28
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc1570
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.h27
-rw-r--r--tensorflow/contrib/lite/toco/format_port.h77
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc98
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc69
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc223
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc56
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc42
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc57
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc98
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc300
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc326
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc108
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h186
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc229
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc170
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc106
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc396
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc103
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc120
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc142
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc1129
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize.cc467
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc105
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc59
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc60
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc38
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc113
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc40
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc68
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc107
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h55
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc87
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc92
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc122
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc135
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc247
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc196
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc76
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc62
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc175
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc51
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc55
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc93
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc49
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc52
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc62
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc86
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc106
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc63
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc54
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc123
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc97
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD31
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc221
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc73
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc1508
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.h34
-rw-r--r--tensorflow/contrib/lite/toco/model.h1372
-rw-r--r--tensorflow/contrib/lite/toco/model_cmdline_flags.cc374
-rw-r--r--tensorflow/contrib/lite/toco/model_cmdline_flags.h43
-rw-r--r--tensorflow/contrib/lite/toco/model_flags.proto119
-rw-r--r--tensorflow/contrib/lite/toco/python/BUILD76
-rw-r--r--tensorflow/contrib/lite/toco/python/toco.i32
-rw-r--r--tensorflow/contrib/lite/toco/python/toco_from_protos.py63
-rw-r--r--tensorflow/contrib/lite/toco/python/toco_from_protos_test.py96
-rw-r--r--tensorflow/contrib/lite/toco/python/toco_python_api.cc85
-rw-r--r--tensorflow/contrib/lite/toco/python/toco_python_api.h33
-rw-r--r--tensorflow/contrib/lite/toco/python/toco_wrapper.py35
-rw-r--r--tensorflow/contrib/lite/toco/runtime/common.h26
-rw-r--r--tensorflow/contrib/lite/toco/runtime/types.h32
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD102
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.cc52
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h101
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.cc34
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h33
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.cc151
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h63
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.cc285
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h82
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc212
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_util.cc197
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_util.h32
-rw-r--r--tensorflow/contrib/lite/toco/tflite/BUILD142
-rw-r--r--tensorflow/contrib/lite/toco/tflite/builtin_operator.h74
-rw-r--r--tensorflow/contrib/lite/toco/tflite/custom_operator.h74
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.cc322
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.h76
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export_test.cc69
-rw-r--r--tensorflow/contrib/lite/toco/tflite/import.cc183
-rw-r--r--tensorflow/contrib/lite/toco/tflite/import.h49
-rw-r--r--tensorflow/contrib/lite/toco/tflite/import_test.cc141
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc627
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.h89
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator_test.cc370
-rw-r--r--tensorflow/contrib/lite/toco/tflite/simple_operator.h50
-rw-r--r--tensorflow/contrib/lite/toco/tflite/types.cc165
-rw-r--r--tensorflow/contrib/lite/toco/tflite/types.h58
-rw-r--r--tensorflow/contrib/lite/toco/tflite/types_test.cc191
-rw-r--r--tensorflow/contrib/lite/toco/toco.cc119
-rw-r--r--tensorflow/contrib/lite/toco/toco_cmdline_flags.cc206
-rw-r--r--tensorflow/contrib/lite/toco/toco_cmdline_flags.h35
-rw-r--r--tensorflow/contrib/lite/toco/toco_flags.proto126
-rw-r--r--tensorflow/contrib/lite/toco/toco_graphviz_dump_options.cc22
-rw-r--r--tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h34
-rw-r--r--tensorflow/contrib/lite/toco/toco_port.cc227
-rw-r--r--tensorflow/contrib/lite/toco/toco_port.h80
-rw-r--r--tensorflow/contrib/lite/toco/toco_port_test.cc58
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc277
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.h50
-rw-r--r--tensorflow/contrib/lite/toco/toco_types.h45
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc1552
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.h292
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util_test.cc96
-rw-r--r--tensorflow/contrib/lite/tools/BUILD60
-rw-r--r--tensorflow/contrib/lite/tools/gen_op_registration.cc46
-rw-r--r--tensorflow/contrib/lite/tools/gen_op_registration.h38
-rw-r--r--tensorflow/contrib/lite/tools/gen_op_registration_main.cc98
-rw-r--r--tensorflow/contrib/lite/tools/gen_op_registration_test.cc87
-rw-r--r--tensorflow/contrib/lite/tools/mutable_op_resolver.cc43
-rw-r--r--tensorflow/contrib/lite/tools/mutable_op_resolver.h45
-rw-r--r--tensorflow/contrib/lite/version.h23
-rw-r--r--tensorflow/contrib/makefile/tf_op_files.txt18
-rw-r--r--tensorflow/contrib/metrics/__init__.py2
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py149
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_test.py262
-rw-r--r--tensorflow/contrib/nccl/BUILD4
-rw-r--r--tensorflow/contrib/nccl/python/ops/nccl_ops_test.py7
-rw-r--r--tensorflow/contrib/quantize/BUILD18
-rw-r--r--tensorflow/contrib/quantize/README.md73
-rw-r--r--tensorflow/contrib/quantize/g3doc/drawings/Fake_Quantization.jpgbin0 -> 32990 bytes
-rw-r--r--tensorflow/contrib/quantize/python/quant_ops.py57
-rw-r--r--tensorflow/contrib/quantize/python/quant_ops_test.py87
-rw-r--r--tensorflow/contrib/quantize/python/quantize.py6
-rw-r--r--tensorflow/contrib/quantize/python/quantize_parameterized_test.py65
-rw-r--r--tensorflow/contrib/quantize/python/quantize_test.py25
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py37
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py6
-rw-r--r--tensorflow/contrib/slim/BUILD2
-rw-r--r--tensorflow/contrib/slim/python/slim/evaluation.py15
-rw-r--r--tensorflow/contrib/slim/python/slim/evaluation_test.py46
-rw-r--r--tensorflow/contrib/summary/BUILD6
-rw-r--r--tensorflow/contrib/summary/summary.py2
-rw-r--r--tensorflow/contrib/summary/summary_ops.py166
-rw-r--r--tensorflow/contrib/summary/summary_ops_test.py122
-rw-r--r--tensorflow/contrib/tensorboard/db/BUILD2
-rw-r--r--tensorflow/contrib/tensorboard/db/summary_db_writer.cc34
-rw-r--r--tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc56
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_config.py14
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py144
-rw-r--r--tensorflow/contrib/training/python/training/hparam.py58
-rw-r--r--tensorflow/contrib/training/python/training/hparam_test.py31
-rw-r--r--tensorflow/core/common_runtime/bfc_allocator.cc13
-rw-r--r--tensorflow/core/framework/bfloat16.cc30
-rw-r--r--tensorflow/core/framework/bfloat16_test.cc92
-rw-r--r--tensorflow/core/framework/numeric_types.h251
-rw-r--r--tensorflow/core/framework/op_gen_lib.cc4
-rw-r--r--tensorflow/core/framework/register_types.h5
-rw-r--r--tensorflow/core/framework/rendezvous.cc18
-rw-r--r--tensorflow/core/graph/graph_constructor.cc10
-rw-r--r--tensorflow/core/graph/graph_constructor.h3
-rw-r--r--tensorflow/core/graph/graph_constructor_test.cc15
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.cc111
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_test.cc32
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler.cc29
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler_test.cc171
-rw-r--r--tensorflow/core/grappler/grappler_item.cc15
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD1
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc33
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc140
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc172
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.h6
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc80
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.cc4
-rw-r--r--tensorflow/core/grappler/utils.cc1
-rw-r--r--tensorflow/core/kernels/BUILD28
-rw-r--r--tensorflow/core/kernels/batch_dataset_op.cc21
-rw-r--r--tensorflow/core/kernels/concat_lib_cpu.cc9
-rw-r--r--tensorflow/core/kernels/concatenate_dataset_op.cc8
-rw-r--r--tensorflow/core/kernels/dataset.cc1
-rw-r--r--tensorflow/core/kernels/dataset.h23
-rw-r--r--tensorflow/core/kernels/fake_quant_ops_functor.h15
-rw-r--r--tensorflow/core/kernels/immutable_constant_op_test.cc4
-rw-r--r--tensorflow/core/kernels/range_dataset_op.cc1
-rw-r--r--tensorflow/core/kernels/reader_dataset_ops.cc1
-rw-r--r--tensorflow/core/kernels/repeat_dataset_op.cc50
-rw-r--r--tensorflow/core/kernels/shuffle_dataset_op.cc31
-rw-r--r--tensorflow/core/kernels/skip_dataset_op.cc63
-rw-r--r--tensorflow/core/kernels/strided_slice_op.cc1
-rw-r--r--tensorflow/core/kernels/strided_slice_op_impl.h2
-rw-r--r--tensorflow/core/kernels/summary_interface.cc4
-rw-r--r--tensorflow/core/kernels/summary_kernels.cc50
-rw-r--r--tensorflow/core/kernels/take_dataset_op.cc59
-rw-r--r--tensorflow/core/kernels/zip_dataset_op.cc63
-rw-r--r--tensorflow/core/lib/core/stringpiece.h11
-rw-r--r--tensorflow/core/lib/io/block.cc2
-rw-r--r--tensorflow/core/lib/strings/str_util.cc4
-rw-r--r--tensorflow/core/lib/strings/strcat.cc2
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt498
-rw-r--r--tensorflow/core/ops/dataset_ops.cc197
-rw-r--r--tensorflow/core/ops/logging_ops.cc2
-rw-r--r--tensorflow/core/ops/ops.pbtxt363
-rw-r--r--tensorflow/core/ops/summary_ops.cc41
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.cc2
-rw-r--r--tensorflow/core/platform/default/build_config.bzl21
-rw-r--r--tensorflow/core/util/bcast.cc2
-rw-r--r--tensorflow/core/util/device_name_utils.cc1
-rw-r--r--tensorflow/core/util/memmapped_file_system.cc5
-rw-r--r--tensorflow/core/util/semver_test.cc2
-rw-r--r--tensorflow/docs_src/mobile/index.md4
-rw-r--r--tensorflow/docs_src/performance/xla/operation_semantics.md89
-rw-r--r--tensorflow/docs_src/programmers_guide/debugger.md26
-rw-r--r--tensorflow/examples/image_retraining/README.md12
-rw-r--r--tensorflow/examples/image_retraining/retrain.py82
-rw-r--r--tensorflow/examples/image_retraining/retrain_test.py23
-rw-r--r--tensorflow/examples/learn/iris.py5
-rw-r--r--tensorflow/examples/learn/wide_n_deep_tutorial.py5
-rw-r--r--tensorflow/go/op/wrappers.go467
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java43
-rw-r--r--tensorflow/python/client/session_clusterspec_prop_test.py2
-rw-r--r--tensorflow/python/client/tf_session.i10
-rw-r--r--tensorflow/python/client/timeline.py2
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py16
-rw-r--r--tensorflow/python/debug/wrappers/dumping_wrapper_test.py18
-rw-r--r--tensorflow/python/eager/BUILD7
-rw-r--r--tensorflow/python/eager/backprop.py20
-rw-r--r--tensorflow/python/eager/backprop_test.py68
-rw-r--r--tensorflow/python/eager/benchmarks_test.py49
-rw-r--r--tensorflow/python/eager/execute.py3
-rw-r--r--tensorflow/python/eager/function.py6
-rw-r--r--tensorflow/python/eager/function_test.py29
-rw-r--r--tensorflow/python/eager/graph_callable.py13
-rw-r--r--tensorflow/python/eager/imperative_grad.py196
-rw-r--r--tensorflow/python/eager/pywrap_tensor.cc8
-rw-r--r--tensorflow/python/eager/pywrap_tensor.h25
-rw-r--r--tensorflow/python/eager/pywrap_tfe.h13
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc358
-rw-r--r--tensorflow/python/eager/tape.py17
-rw-r--r--tensorflow/python/eager/tape_test.py20
-rw-r--r--tensorflow/python/estimator/BUILD63
-rw-r--r--tensorflow/python/estimator/canned/baseline.py349
-rw-r--r--tensorflow/python/estimator/canned/baseline_test.py1545
-rw-r--r--tensorflow/python/estimator/canned/head.py131
-rw-r--r--tensorflow/python/estimator/canned/head_test.py10
-rw-r--r--tensorflow/python/estimator/estimator.py2
-rw-r--r--tensorflow/python/estimator/estimator_lib.py4
-rw-r--r--tensorflow/python/estimator/estimator_test.py66
-rw-r--r--tensorflow/python/framework/function.py2
-rw-r--r--tensorflow/python/framework/function_test.py2
-rw-r--r--tensorflow/python/framework/ops.py105
-rw-r--r--tensorflow/python/framework/ops_test.py87
-rw-r--r--tensorflow/python/framework/test_ops.cc23
-rw-r--r--tensorflow/python/grappler/model_analyzer.cc9
-rw-r--r--tensorflow/python/keras/BUILD12
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/topology.py9
-rw-r--r--tensorflow/python/keras/_impl/keras/integration_test.py2
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/gru_test.py12
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/lstm_test.py11
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/recurrent.py2383
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/recurrent_test.py378
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/simplernn_test.py12
-rw-r--r--tensorflow/python/keras/layers/__init__.py5
-rw-r--r--tensorflow/python/kernel_tests/BUILD1
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py13
-rw-r--r--tensorflow/python/kernel_tests/check_ops_test.py311
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py5
-rw-r--r--tensorflow/python/kernel_tests/distributions/multinomial_test.py8
-rw-r--r--tensorflow/python/kernel_tests/gather_nd_op_test.py10
-rw-r--r--tensorflow/python/kernel_tests/iterator_ops_test.py72
-rw-r--r--tensorflow/python/kernel_tests/range_dataset_op_test.py330
-rw-r--r--tensorflow/python/kernel_tests/reader_dataset_ops_test.py298
-rw-r--r--tensorflow/python/kernel_tests/variable_scope_test.py12
-rw-r--r--tensorflow/python/kernel_tests/xent_op_test.py18
-rw-r--r--tensorflow/python/layers/base.py17
-rw-r--r--tensorflow/python/layers/base_test.py9
-rw-r--r--tensorflow/python/lib/core/strings.i4
-rw-r--r--tensorflow/python/ops/array_grad.py6
-rw-r--r--tensorflow/python/ops/array_ops.py18
-rw-r--r--tensorflow/python/ops/check_ops.py79
-rw-r--r--tensorflow/python/ops/control_flow_ops.py45
-rw-r--r--tensorflow/python/ops/ctc_ops.py30
-rw-r--r--tensorflow/python/ops/embedding_ops.py9
-rw-r--r--tensorflow/python/ops/nn.py1
-rw-r--r--tensorflow/python/ops/nn_grad.py5
-rw-r--r--tensorflow/python/ops/nn_impl.py10
-rw-r--r--tensorflow/python/ops/nn_ops.py164
-rw-r--r--tensorflow/python/ops/rnn.py42
-rw-r--r--tensorflow/python/ops/variable_scope.py5
-rw-r--r--tensorflow/python/pywrap_dlopen_global_flags.py13
-rw-r--r--tensorflow/python/pywrap_tfe.i4
-rw-r--r--tensorflow/python/tools/inspect_checkpoint.py23
-rw-r--r--tensorflow/python/training/monitored_session.py2
-rw-r--r--tensorflow/python/util/tf_should_use.py2
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc18
-rw-r--r--tensorflow/tensorflow.bzl44
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt54
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt54
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt179
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt86
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt179
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt90
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt191
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt179
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt78
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt183
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt20
-rw-r--r--tensorflow/tools/api/golden/tensorflow.nn.pbtxt4
-rwxr-xr-xtensorflow/tools/ci_build/ci_parameterized_build.sh32
-rwxr-xr-xtensorflow/tools/ci_build/ci_sanity.sh10
-rwxr-xr-xtensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh33
-rwxr-xr-xtensorflow/tools/ci_build/osx/cpu/run_contrib.sh2
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel5
-rw-r--r--tensorflow/tools/pip_package/BUILD3
-rw-r--r--tensorflow/tools/pip_package/MANIFEST.in1
-rwxr-xr-xtensorflow/tools/pip_package/build_pip_package.sh3
-rw-r--r--tensorflow/tools/pip_package/setup.py3
-rw-r--r--tensorflow/workspace.bzl45
-rw-r--r--third_party/flatbuffers/flatbuffers.BUILD4
-rw-r--r--third_party/tflite_mobilenet.BUILD13
-rw-r--r--tools/bazel.rc8
760 files changed, 87003 insertions, 3722 deletions
diff --git a/configure.py b/configure.py
index 83ee01c630..0864b6e64b 100644
--- a/configure.py
+++ b/configure.py
@@ -229,17 +229,9 @@ def setup_python(environ_cp):
# Set-up env variables used by python_configure.bzl
write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path)
write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path)
- write_to_bazelrc('build --define PYTHON_BIN_PATH="%s"' % python_bin_path)
- write_to_bazelrc('build --define PYTHON_LIB_PATH="%s"' % python_lib_path)
write_to_bazelrc('build --force_python=py%s' % python_major_version)
write_to_bazelrc('build --host_force_python=py%s' % python_major_version)
write_to_bazelrc('build --python_path=\"%s"' % python_bin_path)
- write_to_bazelrc('test --force_python=py%s' % python_major_version)
- write_to_bazelrc('test --host_force_python=py%s' % python_major_version)
- write_to_bazelrc('test --define PYTHON_BIN_PATH="%s"' % python_bin_path)
- write_to_bazelrc('test --define PYTHON_LIB_PATH="%s"' % python_lib_path)
- write_to_bazelrc('run --define PYTHON_BIN_PATH="%s"' % python_bin_path)
- write_to_bazelrc('run --define PYTHON_LIB_PATH="%s"' % python_lib_path)
environ_cp['PYTHON_BIN_PATH'] = python_bin_path
# Write tools/python_bin_path.sh
@@ -488,10 +480,14 @@ def set_cc_opt_flags(environ_cp):
cc_opt_flags = get_from_env_or_user_or_default(environ_cp, 'CC_OPT_FLAGS',
question, default_cc_opt_flags)
for opt in cc_opt_flags.split():
- host_opt = '-march=native' # It should be safe on the same build host.
- write_to_bazelrc(
- 'build:opt --cxxopt=%s --copt=%s' % (opt, opt) +
- ' --host_cxxopt=%s --host_copt=%s' % (host_opt, host_opt))
+ write_to_bazelrc('build:opt --copt=%s' % opt)
+ # It should be safe on the same build host.
+ write_to_bazelrc('build:opt --host_copt=-march=native')
+ write_to_bazelrc('build:opt --define with_default_optimizations=true')
+ # TODO(mikecase): Remove these default defines once we are able to get
+ # TF Lite targets building without them.
+ write_to_bazelrc('build --copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK')
+ write_to_bazelrc('build --host_copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK')
def set_tf_cuda_clang(environ_cp):
@@ -968,7 +964,6 @@ def set_other_mpi_vars(environ_cp):
def set_mkl():
write_to_bazelrc('build:mkl --define using_mkl=true')
write_to_bazelrc('build:mkl -c opt')
- write_to_bazelrc('build:mkl --copt="-DEIGEN_USE_VML"')
print(
'Add "--config=mkl" to your bazel command to build with MKL '
'support.\nPlease note that MKL on MacOS or windows is still not '
@@ -1023,7 +1018,6 @@ def main():
environ_cp['TF_NEED_OPENCL_SYCL'] = '0'
environ_cp['TF_NEED_COMPUTECPP'] = '0'
environ_cp['TF_NEED_OPENCL'] = '0'
- environ_cp['TF_NEED_S3'] = '0'
environ_cp['TF_CUDA_CLANG'] = '0'
if is_macos():
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 9874f95ea3..54688e84d1 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -119,7 +119,7 @@ config_setting(
config_setting(
name = "no_tensorflow_py_deps",
- values = {"define": "no_tensorflow_py_deps=true"},
+ define_values = {"no_tensorflow_py_deps": "true"},
visibility = ["//visibility:public"],
)
@@ -175,55 +175,122 @@ config_setting(
# TODO(jhseu): Enable on other platforms other than Linux.
config_setting(
name = "with_jemalloc_linux_x86_64",
- values = {
- "cpu": "k8",
- "define": "with_jemalloc=true",
- },
+ define_values = {"with_jemalloc": "true"},
+ values = {"cpu": "k8"},
visibility = ["//visibility:public"],
)
config_setting(
name = "with_jemalloc_linux_ppc64le",
- values = {
- "cpu": "ppc",
- "define": "with_jemalloc=true",
- },
+ define_values = {"with_jemalloc": "true"},
+ values = {"cpu": "ppc"},
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "with_default_optimizations",
+ define_values = {"with_default_optimizations": "true"},
visibility = ["//visibility:public"],
)
config_setting(
name = "with_gcp_support",
- values = {"define": "with_gcp_support=true"},
+ define_values = {"with_gcp_support": "true"},
visibility = ["//visibility:public"],
)
config_setting(
name = "with_hdfs_support",
- values = {"define": "with_hdfs_support=true"},
+ define_values = {"with_hdfs_support": "true"},
visibility = ["//visibility:public"],
)
config_setting(
name = "with_s3_support",
- values = {"define": "with_s3_support=true"},
+ define_values = {"with_s3_support": "true"},
+ visibility = ["//visibility:public"],
+)
+
+# Crosses between platforms and file system libraries not supported on those
+# platforms due to limitations in nested select() statements.
+config_setting(
+ name = "with_gcp_support_windows_override",
+ define_values = {"with_gcp_support": "true"},
+ values = {"cpu": "x64_windows"},
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "with_hdfs_support_windows_override",
+ define_values = {"with_hdfs_support": "true"},
+ values = {"cpu": "x64_windows"},
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "with_s3_support_windows_override",
+ define_values = {"with_s3_support": "true"},
+ values = {"cpu": "x64_windows"},
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "with_gcp_support_android_override",
+ define_values = {"with_gcp_support": "true"},
+ values = {"crosstool_top": "//external:android/crosstool"},
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "with_hdfs_support_android_override",
+ define_values = {"with_hdfs_support": "true"},
+ values = {"crosstool_top": "//external:android/crosstool"},
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "with_s3_support_android_override",
+ define_values = {"with_s3_support": "true"},
+ values = {"crosstool_top": "//external:android/crosstool"},
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "with_gcp_support_ios_override",
+ define_values = {"with_gcp_support": "true"},
+ values = {"crosstool_top": "//tools/osx/crosstool:crosstool"},
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "with_hdfs_support_ios_override",
+ define_values = {"with_hdfs_support": "true"},
+ values = {"crosstool_top": "//tools/osx/crosstool:crosstool"},
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "with_s3_support_ios_override",
+ define_values = {"with_s3_support": "true"},
+ values = {"crosstool_top": "//tools/osx/crosstool:crosstool"},
visibility = ["//visibility:public"],
)
config_setting(
name = "with_xla_support",
- values = {"define": "with_xla_support=true"},
+ define_values = {"with_xla_support": "true"},
visibility = ["//visibility:public"],
)
config_setting(
name = "with_gdr_support",
- values = {"define": "with_gdr_support=true"},
+ define_values = {"with_gdr_support": "true"},
visibility = ["//visibility:public"],
)
config_setting(
name = "with_verbs_support",
- values = {"define": "with_verbs_support=true"},
+ define_values = {"with_verbs_support": "true"},
visibility = ["//visibility:public"],
)
@@ -297,7 +364,7 @@ config_setting(
visibility = ["//visibility:public"],
)
-# Make a dummy rule that we can chaqnge "default" in select statements to.
+# Make a dummy rule that we can change "default" in select statements to.
# to disable dependencies in copybara.
config_setting(
name = "dummy_disabled_internal",
@@ -353,6 +420,7 @@ filegroup(
"//tensorflow/compiler/tf2xla:all_files",
"//tensorflow/compiler/tf2xla/cc:all_files",
"//tensorflow/compiler/tf2xla/kernels:all_files",
+ "//tensorflow/compiler/tf2xla/lib:all_files",
"//tensorflow/compiler/tf2xla/ops:all_files",
"//tensorflow/compiler/xla:all_files",
"//tensorflow/compiler/xla/client:all_files",
@@ -425,6 +493,25 @@ filegroup(
"//tensorflow/contrib/learn/python/learn/datasets:all_files",
"//tensorflow/contrib/linalg:all_files",
"//tensorflow/contrib/linear_optimizer:all_files",
+ "//tensorflow/contrib/lite:all_files",
+ "//tensorflow/contrib/lite/java:all_files",
+ "//tensorflow/contrib/lite/java/demo/app/src/main:all_files",
+ "//tensorflow/contrib/lite/java/demo/app/src/main/assets:all_files",
+ "//tensorflow/contrib/lite/java/src/main/native:all_files",
+ "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:all_files",
+ "//tensorflow/contrib/lite/kernels:all_files",
+ "//tensorflow/contrib/lite/kernels/internal:all_files",
+ "//tensorflow/contrib/lite/models/smartreply:all_files",
+ "//tensorflow/contrib/lite/nnapi:all_files",
+ "//tensorflow/contrib/lite/python:all_files",
+ "//tensorflow/contrib/lite/schema:all_files",
+ "//tensorflow/contrib/lite/testing:all_files",
+ "//tensorflow/contrib/lite/toco:all_files",
+ "//tensorflow/contrib/lite/toco/graph_transformations/tests:all_files",
+ "//tensorflow/contrib/lite/toco/python:all_files",
+ "//tensorflow/contrib/lite/toco/tensorflow_graph_matching:all_files",
+ "//tensorflow/contrib/lite/toco/tflite:all_files",
+ "//tensorflow/contrib/lite/tools:all_files",
"//tensorflow/contrib/lookup:all_files",
"//tensorflow/contrib/losses:all_files",
"//tensorflow/contrib/makefile:all_files",
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 6dd1b99910..dd638de3c6 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -890,8 +890,8 @@ const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
TF_Status* status) {
const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name);
if (attr == nullptr) {
- status->status =
- InvalidArgument("Operation has no attr named '", attr_name, "'.");
+ status->status = InvalidArgument("Operation '", oper->node.name(),
+ "' has no attr named '", attr_name, "'.");
}
return attr;
}
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index 05881e619b..e0057eb51c 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -383,7 +383,7 @@ TEST(CAPI, Graph) {
EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s));
ASSERT_FALSE(GetAttrValue(feed, "missing", &attr_value, s));
- EXPECT_EQ(string("Operation has no attr named 'missing'."),
+ EXPECT_EQ(string("Operation 'feed' has no attr named 'missing'."),
string(TF_Message(s)));
// Make a constant oper with the scalar "3".
@@ -1054,7 +1054,7 @@ class CApiColocationTest : public ::testing::Test {
TF_OperationGetAttrMetadata(op, tensorflow::kColocationAttrName, s_);
if (expected.empty()) {
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
- EXPECT_EQ(std::string("Operation has no attr named '_class'."),
+ EXPECT_EQ(std::string("Operation 'add' has no attr named '_class'."),
std::string(TF_Message(s_)));
return;
}
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index c77896b80b..d533758e36 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -39,6 +39,7 @@ tf_cuda_library(
tf_cuda_library(
name = "c_api_internal",
hdrs = ["c_api_internal.h"],
+ visibility = ["//tensorflow:internal"],
deps = [
":c_api",
":runtime",
@@ -105,7 +106,6 @@ tf_cc_test(
cc_library(
name = "tape",
- srcs = ["tape.cc"],
hdrs = ["tape.h"],
visibility = ["//tensorflow:internal"],
deps = [
diff --git a/tensorflow/c/eager/tape.cc b/tensorflow/c/eager/tape.cc
deleted file mode 100644
index 464612a81e..0000000000
--- a/tensorflow/c/eager/tape.cc
+++ /dev/null
@@ -1,102 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/c/eager/tape.h"
-
-namespace tensorflow {
-namespace eager {
-
-bool GradientTape::ShouldRecord(gtl::ArraySlice<int64> tensor_ids) {
- for (int64 i : tensor_ids) {
- if (tensor_tape_.find(i) != tensor_tape_.end()) {
- return true;
- }
- }
- return false;
-}
-
-void GradientTape::Watch(int64 tensor_id) {
- tensor_tape_.emplace(tensor_id, -1);
-}
-
-void GradientTape::RecordOperation(
- const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
- gtl::ArraySlice<int64> input_tensor_id, void* backward_function,
- const std::function<void()>& backward_function_deleter) {
- if (!ShouldRecord(input_tensor_id)) {
- backward_function_deleter();
- return;
- }
- std::vector<int64> ids;
- ids.reserve(input_tensor_id.size());
- for (int64 i : input_tensor_id) {
- tensor_usage_[i]++;
- ids.push_back(i);
- }
- const int64 op_id = next_op_id_++;
- std::vector<TapeTensor> tensors;
- tensors.reserve(output_tensors.size());
- for (const TapeTensor& o : output_tensors) {
- // Note: the tensor can have already been watched and hence be in the tape,
- // so we cannot check that we're inserting it here.
- tensor_tape_[o.id] = op_id;
- tensor_usage_[o.id] = 1;
- tensors.push_back(o);
- }
- op_tape_[op_id] = OpTapeEntry{op_type, tensors, ids, backward_function,
- backward_function_deleter};
-}
-
-void GradientTape::DeleteTrace(int64 tensor_id) {
- auto it = tensor_usage_.find(tensor_id);
- if (it == tensor_usage_.end()) {
- return;
- }
- it->second--;
- if (it->second != 0) {
- return;
- }
- tensor_usage_.erase(it);
- auto tensor_op_it = tensor_tape_.find(tensor_id);
- if (tensor_op_it == tensor_tape_.end()) {
- return;
- }
- const int64 op_id = tensor_op_it->second;
- if (op_id == -1) {
- // Do not delete watched tensors.
- return;
- }
- tensor_tape_.erase(tensor_op_it);
- auto op_it = op_tape_.find(op_id);
- CHECK(op_it != op_tape_.end());
- for (const auto& output : op_it->second.output_tensor_info) {
- if (tensor_usage_.find(output.id) != tensor_usage_.end()) {
- // Found a usage for an output, so cannot delete the op.
- return;
- }
- }
- for (int64 id : op_it->second.input_tensor_id) {
- DeleteTrace(id);
- }
- op_it->second.backward_function_deleter();
- op_tape_.erase(op_it);
-}
-
-std::pair<TensorTape, OpTape> GradientTape::Export() {
- return {std::move(tensor_tape_), std::move(op_tape_)};
-}
-
-} // namespace eager
-} // namespace tensorflow
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index df51f300eb..29d73c5ca4 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -19,6 +19,7 @@ limitations under the License.
// maintains the data structures required to do so.
#include <unordered_map>
+#include <unordered_set>
#include <vector>
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
@@ -36,13 +37,14 @@ struct TapeTensor {
};
// Represents an entry in the tape.
+template <typename BackwardFunction>
struct OpTapeEntry {
string op_type;
std::vector<TapeTensor> output_tensor_info;
std::vector<int64> input_tensor_id;
// TODO(apassos) consider narrowing down this interface.
- void* backward_function;
+ BackwardFunction* backward_function;
// Should be called before deleting the backward function. TODO(apassos) use
// unique_ptrs to ensure this happens.
@@ -55,13 +57,68 @@ struct OpTapeEntry {
using TensorTape = std::unordered_map<int64, int64>;
// Map from operation-id to tape entry.
-using OpTape = std::unordered_map<int64, OpTapeEntry>;
+template <typename BackwardFunction>
+using OpTape = std::unordered_map<int64, OpTapeEntry<BackwardFunction>>;
+
+// Operations the tape needs to perform on tensors to do backpropagation. Named
+// "vspace" because a subset of these are related to a vector space, such as
+// adding gradients, getting zeroes, etc. Currently cannot be implemented
+// without using tensorflow python code, hence left unspecified here.
+//
+// Gradient is the type returned by gradient functions. In Python TF it's either
+// Tensor or IndexedSlices or None, which here we map to nullptr. Gradients need
+// to allow their size to be computed and they need to be passable to a backward
+// function and deleted (as the backprop code creates lots of gradients the user
+// is not interested in).
+//
+// BackwardFunction needs to be a closure which stores intermediate activations
+// from the forward computation and calls a vector-jacobian product function
+// (also known as adjoint function) to compute, given downstream gradients,
+// upstream gradients.
+//
+// TODO(apassos) provide concrete template instantiations for TFE_TensorHandle
+// specialization, which is blocked by quite a few things needing to loop back
+// into python now.
+template <typename Gradient, typename BackwardFunction>
+class VSpace {
+ public:
+ virtual ~VSpace() {}
+
+ // Returns the number of elements in the gradient tensor.
+ virtual int64 NumElements(Gradient* tensor) const = 0;
+
+ // Consumes references to the tensors in the gradient_tensors list and returns
+ // a tensor with the result.
+ virtual Gradient* AggregateGradients(
+ gtl::ArraySlice<Gradient*> gradient_tensors) const = 0;
+
+ // Returns a tensor of the right shape and dtype filled with zeros.
+ virtual Gradient* Zeros(TensorShape shape, DataType dtype) const = 0;
+
+ // Returns a Tensor which is filled with ones and like the input.
+ virtual Gradient* Ones(TensorShape shape, DataType dtype) const = 0;
+
+ // Calls the passed-in backward function.
+ virtual Status CallBackwardFunction(
+ BackwardFunction* backward_function,
+ gtl::ArraySlice<Gradient*> output_gradients,
+ std::vector<Gradient*>* result) const = 0;
+
+ // Deletes the input tensor.
+ virtual void DeleteGradient(Gradient* gradient) const = 0;
+};
// Traces the execution of operations, doing eager garbage collection, and
// exporting a full trace so other code can do backpropagation. Not thread-safe.
+template <typename Gradient, typename BackwardFunction>
class GradientTape {
public:
GradientTape() {}
+ ~GradientTape() {
+ for (const auto& pair : op_tape_) {
+ pair.second.backward_function_deleter();
+ }
+ }
bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids);
@@ -70,19 +127,24 @@ class GradientTape {
void RecordOperation(const string& op_type,
gtl::ArraySlice<TapeTensor> output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
- void* backward_function,
+ BackwardFunction* backward_function,
const std::function<void()>& backward_function_deleter);
void DeleteTrace(int64 tensor_id);
- // Note: it is only valid to call Export once per tape, and after calling
- // export the tape is no longer valid (i.e. calls to ShouldRecord, Watch,
- // Record, and Delete have undefined behavior).
- std::pair<TensorTape, OpTape> Export();
+ // Consumes the internal state of the tape (so cannot be called more than
+ // once) and produces the gradient of the target tensors with respect to the
+ // source tensors. The output gradients are used if not empty and not
+ // null. The result is populated with one tensor per target element.
+ Status ComputeGradient(const VSpace<Gradient, BackwardFunction>& vspace,
+ gtl::ArraySlice<int64> target_tensor_ids,
+ gtl::ArraySlice<int64> source_tensor_id,
+ gtl::ArraySlice<Gradient*> output_gradients,
+ std::vector<Gradient*>* result);
private:
TensorTape tensor_tape_;
- OpTape op_tape_;
+ OpTape<BackwardFunction> op_tape_;
int64 next_op_id_{0};
// Map from tensor id to number of remaining usages (i.e. how many entries in
@@ -90,6 +152,429 @@ class GradientTape {
std::unordered_map<int64, int64> tensor_usage_;
};
+// Template instantiations here
+
+template <typename Gradient, typename BackwardFunction>
+bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
+ gtl::ArraySlice<int64> tensor_ids) {
+ for (int64 i : tensor_ids) {
+ if (tensor_tape_.find(i) != tensor_tape_.end()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+template <typename Gradient, typename BackwardFunction>
+void GradientTape<Gradient, BackwardFunction>::Watch(int64 tensor_id) {
+ tensor_tape_.emplace(tensor_id, -1);
+}
+
+template <typename Gradient, typename BackwardFunction>
+void GradientTape<Gradient, BackwardFunction>::RecordOperation(
+ const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
+ gtl::ArraySlice<int64> input_tensor_id, BackwardFunction* backward_function,
+ const std::function<void()>& backward_function_deleter) {
+ if (!ShouldRecord(input_tensor_id)) {
+ backward_function_deleter();
+ return;
+ }
+ std::vector<int64> ids;
+ ids.reserve(input_tensor_id.size());
+ for (int64 i : input_tensor_id) {
+ tensor_usage_[i]++;
+ ids.push_back(i);
+ }
+ const int64 op_id = next_op_id_++;
+ std::vector<TapeTensor> tensors;
+ tensors.reserve(output_tensors.size());
+ for (const TapeTensor& o : output_tensors) {
+ // Note: the tensor can have already been watched and hence be in the tape,
+ // so we cannot check that we're inserting it here.
+ tensor_tape_[o.id] = op_id;
+ tensor_usage_[o.id] = 1;
+ tensors.push_back(o);
+ }
+ op_tape_[op_id] = OpTapeEntry<BackwardFunction>{
+ op_type, tensors, ids, backward_function, backward_function_deleter};
+}
+
+template <typename Gradient, typename BackwardFunction>
+void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) {
+ auto it = tensor_usage_.find(tensor_id);
+ if (it == tensor_usage_.end()) {
+ return;
+ }
+ it->second--;
+ if (it->second != 0) {
+ return;
+ }
+ tensor_usage_.erase(it);
+ auto tensor_op_it = tensor_tape_.find(tensor_id);
+ if (tensor_op_it == tensor_tape_.end()) {
+ return;
+ }
+ const int64 op_id = tensor_op_it->second;
+ if (op_id == -1) {
+ // Do not delete watched tensors.
+ return;
+ }
+ tensor_tape_.erase(tensor_op_it);
+ auto op_it = op_tape_.find(op_id);
+ CHECK(op_it != op_tape_.end());
+ for (const auto& output : op_it->second.output_tensor_info) {
+ if (tensor_usage_.find(output.id) != tensor_usage_.end()) {
+ // Found a usage for an output, so cannot delete the op.
+ return;
+ }
+ }
+ for (int64 id : op_it->second.input_tensor_id) {
+ DeleteTrace(id);
+ }
+ op_it->second.backward_function_deleter();
+ op_tape_.erase(op_it);
+}
+
+// Terminology:
+//
+// - op: a possibly composite operation, which has an entry in the tape
+// - target: dy in dx/dy
+// - source: dx in dx/dy
+// - tensor: one of the many inputs or outputs of an operation
+//
+// Below here we do the gradient algorithm. It works as follows:
+//
+// First we filter the tape to just the subset of operations we want to
+// differentiate. In the process of doing so we count how many times each Tensor
+// is used as an input to an op (so we know when we're done computing gradients
+// for that Tensor). We also count, for each tape entry, how many of its output
+// Tensors need gradients to be computed (Tensors which are not used do not need
+// any gradients to be computed).
+//
+// Finally, we start a backprop stack with a set of tape entries for which we
+// have all gradients available. This set usually is a subset of the set of
+// targets (not all since targets which have outputs in the tape will not have
+// gradients available initially).
+//
+// Then we repeatedly pop an entry from the stack, run its backprop, and update
+// the gradients of its inputs. Once we have computed all gradients for a single
+// input we can mark this input as done, and this can trigger adding an entry to
+// the stack if all outputs of that entry are now done.
+//
+// When the stack is empty we have gradients for all tensors we're interested
+// in.
+
+namespace {
+
+template <typename BackwardFunction>
+struct BackpropInitialState {
+ OpTape<BackwardFunction> op_tape;
+
+ // Map from tensor ID to how many references still exist for this tensor in
+ // the tape.
+ std::unordered_map<int64, int64> tensor_usage_counts;
+
+ // Maps from op ID to how many output tensors of this op still need to have
+ // their gradients computed.
+ std::unordered_map<int64, int64> op_missing_tensor;
+};
+
+template <typename BackwardFunction>
+BackpropInitialState<BackwardFunction> PrepareBackprop(
+ gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
+ OpTape<BackwardFunction> op_tape,
+ const std::unordered_set<int64>& sources_set) {
+ std::vector<int64> tensor_stack;
+ tensor_stack.reserve(target.size());
+ for (auto t : target) {
+ tensor_stack.push_back(t);
+ }
+ BackpropInitialState<BackwardFunction> result;
+ while (!tensor_stack.empty()) {
+ int64 tensor_id = tensor_stack.back();
+ tensor_stack.pop_back();
+ auto op_id_it = tensor_tape.find(tensor_id);
+ if (op_id_it == tensor_tape.end()) {
+ continue;
+ }
+ int64 op_id = op_id_it->second;
+ auto op_it = op_tape.find(op_id);
+ auto result_op_it = result.op_tape.find(op_id);
+ if (op_id == -1 || op_it == op_tape.end() ||
+ result_op_it != result.op_tape.end()) {
+ continue;
+ }
+ CHECK(result.op_tape.emplace(op_id, op_it->second).second);
+ for (auto it : op_it->second.input_tensor_id) {
+ auto count_it = result.tensor_usage_counts.find(it);
+ if (count_it != result.tensor_usage_counts.end()) {
+ count_it->second++;
+ } else {
+ result.tensor_usage_counts[it] = 1;
+ if (sources_set.find(it) == sources_set.end() &&
+ tensor_tape.find(it) != tensor_tape.end()) {
+ tensor_stack.push_back(it);
+ }
+ }
+ }
+ op_tape.erase(op_it);
+ }
+ for (auto& pair : result.tensor_usage_counts) {
+ auto it = tensor_tape.find(pair.first);
+ if (it != tensor_tape.end() && it->second != -1) {
+ result.op_missing_tensor[it->second] += 1;
+ }
+ }
+ // Call destructors for all unneeded gradient functions.
+ for (const auto& op_pair : op_tape) {
+ op_pair.second.backward_function_deleter();
+ }
+ return result;
+}
+
+template <typename BackwardFunction>
+std::vector<int64> InitialStack(
+ const OpTape<BackwardFunction>& op_tape,
+ const std::unordered_map<int64, int64>& op_missing_tensor) {
+ std::vector<int64> result;
+ for (auto& op_entry : op_tape) {
+ if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) {
+ result.push_back(op_entry.first);
+ }
+ }
+ return result;
+}
+
+template <typename Gradient, typename BackwardFunction>
+Status InitialGradients(
+ const VSpace<Gradient, BackwardFunction>& vspace,
+ gtl::ArraySlice<int64> target_tensor_ids,
+ gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
+ const OpTape<BackwardFunction>& op_tape,
+ const std::unordered_map<int64, int64>& tensor_usage_counts,
+ std::unordered_map<int64, std::vector<Gradient*>>* result) {
+ for (int i = 0; i < target_tensor_ids.size(); ++i) {
+ const int64 id = target_tensor_ids[i];
+ if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) {
+ if (!output_gradients.empty() && output_gradients[i] != nullptr) {
+ // TODO(apassos) figure out how to print debugging information here.
+ return errors::InvalidArgument(
+ "A gradient was provided for a tensor which is used as part of the "
+ "computation.");
+ }
+ } else {
+ if (output_gradients.empty() || output_gradients[i] == nullptr) {
+ auto tensor_it = tensor_tape.find(id);
+ if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
+ auto op_it = op_tape.find(tensor_it->second);
+ if (op_it == op_tape.end()) {
+ return errors::Internal(
+ "Internal state of the gradient tape is invalid.");
+ }
+ bool found = false;
+ for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
+ if (op_it->second.output_tensor_info[j].id == id) {
+ found = true;
+ (*result)[id].push_back(
+ vspace.Ones(op_it->second.output_tensor_info[j].shape,
+ op_it->second.output_tensor_info[j].dtype));
+ break;
+ }
+ }
+ if (!found) {
+ return errors::Internal(
+ "Internal state of the gradient tape is invalid.");
+ }
+ } else {
+ // No record of the target tensor found on the tape, so no gradient
+ // needs to be computed from it. Do nothing.
+ }
+ } else {
+ (*result)[id].push_back(output_gradients[i]);
+ }
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+// If over kMinAggregateCount gradients are accumulated and the total
+// memory consumption is over kMinAggregateBytes, do an early aggregation
+// so as to release the gradient tensor to save memory.
+constexpr int kMinAggregateCount = 4;
+constexpr int kMinAggregateBytes = 128 * 1024 * 1024;
+
+template <typename Gradient, typename BackwardFunction>
+Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
+ const VSpace<Gradient, BackwardFunction>& vspace,
+ gtl::ArraySlice<int64> target_tensor_ids,
+ gtl::ArraySlice<int64> source_tensor_ids,
+ gtl::ArraySlice<Gradient*> output_gradients,
+ std::vector<Gradient*>* result) {
+ std::unordered_set<int64> sources_set(source_tensor_ids.begin(),
+ source_tensor_ids.end());
+ BackpropInitialState<BackwardFunction> state = PrepareBackprop(
+ target_tensor_ids, tensor_tape_, std::move(op_tape_), sources_set);
+ std::vector<int64> op_stack =
+ InitialStack(state.op_tape, state.op_missing_tensor);
+ std::unordered_map<int64, std::vector<Gradient*>> gradients;
+ Status s = InitialGradients(vspace, target_tensor_ids, output_gradients,
+ tensor_tape_, state.op_tape,
+ state.tensor_usage_counts, &gradients);
+ auto cleanup = [&state]() {
+ // Release all backprop functions
+ for (const auto& pair : state.op_tape) {
+ pair.second.backward_function_deleter();
+ }
+ };
+ if (!s.ok()) {
+ cleanup();
+ return s;
+ }
+ std::unordered_map<int64, int64> gradients_size;
+ // TODO(apassos) multiple threads could be dequeuing from op_stack at the same
+ // time, for better CPU backprop performance.
+ VLOG(1) << "Initial stack:";
+ if (VLOG_IS_ON(1)) {
+ for (auto t : op_stack) {
+ VLOG(1) << " " << t;
+ }
+ }
+ std::unordered_map<string, std::unordered_set<int>>
+ functions_accept_none_for_indices({
+ {"SoftmaxCrossEntropyWithLogits", {1}},
+ {"FusedBatchNorm", {1, 2, 3, 4}},
+ });
+ while (!op_stack.empty()) {
+ const int64 op = op_stack.back();
+ VLOG(1) << "Popped " << op;
+ op_stack.pop_back();
+ auto op_it = state.op_tape.find(op);
+ if (op_it == state.op_tape.end()) {
+ // It is possible for ops to end up on the stack if they are unrelated to
+ // the target; we should just skip them.
+ continue;
+ }
+ auto trace = std::move(op_it->second);
+ state.op_tape.erase(op_it);
+ std::vector<Gradient*> out_gradients;
+ out_gradients.reserve(trace.output_tensor_info.size());
+ for (int i = 0; i < trace.output_tensor_info.size(); ++i) {
+ const int64 id = trace.output_tensor_info[i].id;
+ auto grad_it = gradients.find(id);
+ if (grad_it == gradients.end()) {
+ auto func_name_it =
+ functions_accept_none_for_indices.find(trace.op_type);
+ if (func_name_it != functions_accept_none_for_indices.end() &&
+ func_name_it->second.find(i) != func_name_it->second.end()) {
+ out_gradients.push_back(nullptr);
+ } else {
+ out_gradients.push_back(
+ vspace.Zeros(trace.output_tensor_info[i].shape,
+ trace.output_tensor_info[i].dtype));
+ }
+ } else {
+ out_gradients.push_back(vspace.AggregateGradients(grad_it->second));
+ if (sources_set.find(grad_it->first) == sources_set.end()) {
+ gradients.erase(grad_it);
+ }
+ }
+ }
+ std::vector<Gradient*> in_gradients;
+ Status s = vspace.CallBackwardFunction(trace.backward_function,
+ out_gradients, &in_gradients);
+ if (!s.ok()) {
+ VLOG(1) << "Gradient function failed.";
+ cleanup();
+ return s;
+ }
+ VLOG(1) << "Got " << in_gradients.size() << " in_gradients for "
+ << trace.input_tensor_id.size() << " sources";
+ for (int i = 0; i < in_gradients.size(); ++i) {
+ const int64 id = trace.input_tensor_id[i];
+ if (in_gradients[i] != nullptr) {
+ auto& unaggregated_grads = gradients[id];
+ unaggregated_grads.push_back(in_gradients[i]);
+ if (unaggregated_grads.size() > kMinAggregateCount) {
+ auto size_it = gradients_size.find(id);
+ int64 size;
+ if (size_it == gradients_size.end()) {
+ size = vspace.NumElements(unaggregated_grads[0]);
+ gradients_size.emplace(id, size);
+ } else {
+ size = size_it->second;
+ }
+ if (unaggregated_grads.size() * size * 4 > kMinAggregateBytes) {
+ Gradient* grad = vspace.AggregateGradients(unaggregated_grads);
+ unaggregated_grads.clear();
+ unaggregated_grads.push_back(grad);
+ }
+ }
+ }
+ auto usage_count_it = state.tensor_usage_counts.find(id);
+ if (usage_count_it == state.tensor_usage_counts.end()) {
+ VLOG(1) << "Tensor " << id << " not used";
+ continue;
+ }
+ usage_count_it->second--;
+ if (usage_count_it->second > 0) {
+ VLOG(1) << "Tensor " << id << " usage count " << usage_count_it->second;
+ continue;
+ }
+ auto tape_it = tensor_tape_.find(id);
+ if (tape_it == tensor_tape_.end()) {
+ VLOG(1) << "Tensor " << id
+ << " has no associated op. Deleting gradient";
+ auto grad_it = gradients.find(id);
+ if (grad_it != gradients.end()) {
+ for (auto g : grad_it->second) {
+ vspace.DeleteGradient(g);
+ }
+ gradients.erase(grad_it);
+ }
+ continue;
+ }
+ const int64 op_id = tape_it->second;
+ if (op_id == -1) {
+ VLOG(1) << "Tensor " << id << " is source";
+ continue;
+ }
+ auto missing_it = state.op_missing_tensor.find(op_id);
+ if (missing_it != state.op_missing_tensor.end()) {
+ missing_it->second--;
+ VLOG(1) << "Op " << op_id << " missing " << missing_it->second
+ << " output gradients";
+ if (missing_it->second == 0) {
+ op_stack.push_back(op_id);
+ }
+ }
+ }
+ }
+ CHECK(state.op_tape.empty());
+ result->reserve(source_tensor_ids.size());
+ for (auto is : source_tensor_ids) {
+ auto grad_it = gradients.find(is);
+ if (grad_it == gradients.end()) {
+ result->push_back(nullptr);
+ } else {
+ if (grad_it->second.size() == 1) {
+ result->push_back(grad_it->second[0]);
+ } else {
+ result->push_back(vspace.AggregateGradients(grad_it->second));
+ }
+ gradients.erase(grad_it);
+ }
+ }
+ VLOG(1) << "Final gradients size: " << gradients.size();
+ for (auto grad_pair : gradients) {
+ for (const auto& g : grad_pair.second) {
+ vspace.DeleteGradient(g);
+ }
+ }
+ return Status::OK();
+}
+
} // namespace eager
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 27c5da08c1..e481796d9e 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -257,7 +257,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
options.graph_def_version = ctx->function_library()->graph_def_version();
options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId);
- options.local_executable_has_hybrid_result = true;
const XlaCompiler::CompilationResult* kernel;
xla::LocalExecutable* executable;
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 23368b6c76..bc2eccd277 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -227,10 +227,7 @@ Status XlaCompilationCache::BuildExecutable(
}
xla::ExecutableBuildOptions build_options;
build_options.set_device_ordinal(client_->default_device_ordinal());
- build_options.set_platform(client_->platform());
build_options.set_result_layout(result.xla_output_shape);
- build_options.set_has_hybrid_result(
- options.local_executable_has_hybrid_result);
auto compile_result =
client_->Compile(*result.computation, argument_layouts, build_options);
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 284ecbf97d..79c4befd36 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -130,6 +130,21 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "cholesky_op_test",
+ size = "small",
+ srcs = ["cholesky_op_test.py"],
+ tags = ["optonly"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:training",
+ ],
+)
+
+tf_xla_py_test(
name = "clustering_test",
size = "small",
srcs = ["clustering_test.py"],
diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py
new file mode 100644
index 0000000000..5010fe5e21
--- /dev/null
+++ b/tensorflow/compiler/tests/cholesky_op_test.py
@@ -0,0 +1,126 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.tf.Cholesky."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class CholeskyOpTest(XLATestCase):
+
+ def _verifyCholeskyBase(self, sess, placeholder, x, chol, verification, atol):
+ chol_np, verification_np = sess.run([chol, verification], {placeholder: x})
+ self.assertAllClose(x, verification_np, atol=atol)
+ self.assertShapeEqual(x, chol)
+ # Check that the cholesky is lower triangular, and has positive diagonal
+ # elements.
+ if chol_np.shape[-1] > 0:
+ chol_reshaped = np.reshape(chol_np, (-1, chol_np.shape[-2],
+ chol_np.shape[-1]))
+ for chol_matrix in chol_reshaped:
+ self.assertAllClose(chol_matrix, np.tril(chol_matrix), atol=atol)
+ self.assertTrue((np.diag(chol_matrix) > 0.0).all())
+
+ def _verifyCholesky(self, x, atol=1e-6):
+ # Verify that LL^T == x.
+ with self.test_session() as sess:
+ placeholder = array_ops.placeholder(
+ dtypes.as_dtype(x.dtype), shape=x.shape)
+ with self.test_scope():
+ chol = linalg_ops.cholesky(placeholder)
+ verification = math_ops.matmul(chol, chol, adjoint_b=True)
+ self._verifyCholeskyBase(sess, placeholder, x, chol, verification, atol)
+
+ def testBasic(self):
+ data = np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]])
+ for dtype in self.float_types:
+ self._verifyCholesky(data.astype(dtype))
+
+ def testBatch(self):
+ for dtype in self.float_types:
+ simple_array = np.array(
+ [[[1., 0.], [0., 5.]]], dtype=dtype) # shape (1, 2, 2)
+ self._verifyCholesky(simple_array)
+ self._verifyCholesky(np.vstack((simple_array, simple_array)))
+ odd_sized_array = np.array(
+ [[[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]], dtype=dtype)
+ self._verifyCholesky(np.vstack((odd_sized_array, odd_sized_array)))
+
+ # Generate random positive-definite matrices.
+ matrices = np.random.rand(10, 5, 5).astype(dtype)
+ for i in xrange(10):
+ matrices[i] = np.dot(matrices[i].T, matrices[i])
+ self._verifyCholesky(matrices, atol=1e-4)
+
+ def testNonSquareMatrix(self):
+ for dtype in self.float_types:
+ with self.assertRaises(ValueError):
+ linalg_ops.cholesky(np.array([[1., 2., 3.], [3., 4., 5.]], dtype=dtype))
+ with self.assertRaises(ValueError):
+ linalg_ops.cholesky(
+ np.array(
+ [[[1., 2., 3.], [3., 4., 5.]], [[1., 2., 3.], [3., 4., 5.]]],
+ dtype=dtype))
+
+ def testWrongDimensions(self):
+ for dtype in self.float_types:
+ tensor3 = constant_op.constant([1., 2.], dtype=dtype)
+ with self.assertRaises(ValueError):
+ linalg_ops.cholesky(tensor3)
+ with self.assertRaises(ValueError):
+ linalg_ops.cholesky(tensor3)
+
+ @unittest.skip("Test is slow")
+ def testLarge(self):
+ n = 200
+ shape = (n, n)
+ data = np.ones(shape).astype(np.float32) / (2.0 * n) + np.diag(
+ np.ones(n).astype(np.float32))
+ self._verifyCholesky(data, atol=1e-4)
+
+ def testMatrixConditionNumbers(self):
+ for dtype in self.float_types:
+ condition_number = 1000
+ size = 20
+
+ # Generate random positive-definite symmetric matrices, and take their
+ # Eigendecomposition.
+ matrix = np.random.rand(size, size)
+ matrix = np.dot(matrix.T, matrix)
+ _, w = np.linalg.eigh(matrix)
+
+ # Build new Eigenvalues exponentially distributed between 1 and
+ # 1/condition_number
+ v = np.exp(-np.log(condition_number) * np.linspace(0, size, size) / size)
+ matrix = np.dot(np.dot(w, np.diag(v)), w.T).astype(dtype)
+ self._verifyCholesky(matrix, atol=1e-4)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 912e819d8d..376c8108ed 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -125,6 +125,7 @@ cc_library(
":functionalize_control_flow",
":sharding_util",
":tf2xla_util",
+ "//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 13d06177f0..948d7f0b40 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -19,6 +19,7 @@ tf_kernel_library(
"binary_ops.cc",
"cast_op.cc",
"categorical_op.cc",
+ "cholesky_op.cc",
"concat_op.cc",
"const_op.cc",
"conv_ops.cc",
@@ -81,6 +82,8 @@ tf_kernel_library(
":while_op",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/tf2xla/lib:batch_dot",
+ "//tensorflow/compiler/tf2xla/lib:cholesky",
"//tensorflow/compiler/tf2xla/ops:sendrecv_ops",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
@@ -91,6 +94,7 @@ tf_kernel_library(
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:linalg_ops_op_lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:bounds_check",
"//tensorflow/core/kernels:concat_lib",
diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc
index 73ccc151c1..a015b8e0e8 100644
--- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc
@@ -13,11 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// XLA-specific BatchMatMul Op.
-// The current implementation simply unrolls the computation along the batch
-// dimension.
-// TODO(dominikg,phawkins): Use a real batched matmul instead of unrolling.
-
+#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -32,110 +28,10 @@ class BatchMatMulOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- const TensorShape x_shape = ctx->InputShape(0);
- const TensorShape y_shape = ctx->InputShape(1);
-
- // Check that both tensors have the same number of dimensions. There must be
- // at least two (the batch dimensions can be empty).
- OP_REQUIRES(ctx, x_shape.dims() == y_shape.dims(),
- errors::InvalidArgument("In[0] and In[1] has different ndims: ",
- x_shape.DebugString(), " vs. ",
- y_shape.DebugString()));
- const int ndims = x_shape.dims();
- OP_REQUIRES(
- ctx, ndims >= 2,
- errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));
-
- // The batch dimensions must be equal and the matrix dimensions must be
- // valid.
- std::vector<int64> dimensions;
- int batch_count = 1;
- for (int i = 0; i < ndims - 2; ++i) {
- OP_REQUIRES(
- ctx, x_shape.dim_size(i) == y_shape.dim_size(i),
- errors::InvalidArgument("In[0].dim(", i, ") and In[1].dim(", i,
- ") must be the same: ", x_shape.DebugString(),
- " vs ", y_shape.DebugString()));
- dimensions.push_back(x_shape.dim_size(i));
- batch_count *= x_shape.dim_size(i);
- }
-
- int x_inner_dim = adj_x_ ? (ndims - 2) : (ndims - 1);
- int y_inner_dim = adj_y_ ? (ndims - 1) : (ndims - 2);
- OP_REQUIRES(
- ctx, x_shape.dim_size(x_inner_dim) == y_shape.dim_size(y_inner_dim),
- errors::InvalidArgument(
- "In[0] mismatch In[1] shape: ", x_shape.dim_size(x_inner_dim),
- " vs. ", y_shape.dim_size(y_inner_dim), ": ", x_shape.DebugString(),
- " ", y_shape.DebugString(), " ", adj_x_, " ", adj_y_));
-
- int x_outer_dim = adj_x_ ? (ndims - 1) : (ndims - 2);
- int y_outer_dim = adj_y_ ? (ndims - 2) : (ndims - 1);
- dimensions.push_back(x_shape.dim_size(x_outer_dim));
- dimensions.push_back(y_shape.dim_size(y_outer_dim));
-
- xla::ComputationBuilder* builder = ctx->builder();
-
- xla::ComputationDataHandle x_handle = ctx->Input(0);
- if (BaseType(input_type(0)) == DT_COMPLEX64 && adj_x_) {
- x_handle = builder->Conj(x_handle);
- }
- xla::ComputationDataHandle y_handle = ctx->Input(1);
- if (BaseType(input_type(1)) == DT_COMPLEX64 && adj_y_) {
- y_handle = builder->Conj(y_handle);
- }
-
- // Reshape input tensors into 3D tensors by flattening the batch
- // dimensions. This makes it easier to unroll the batch dimension.
- auto x_flat =
- builder->Reshape(x_handle, {batch_count, x_shape.dim_size(ndims - 2),
- x_shape.dim_size(ndims - 1)});
- auto y_flat =
- builder->Reshape(y_handle, {batch_count, y_shape.dim_size(ndims - 2),
- y_shape.dim_size(ndims - 1)});
-
- // Slice batches into individual matrices and multiply them.
- std::vector<xla::ComputationDataHandle> out_slices;
- for (int i = 0; i < batch_count; ++i) {
- // Slice off individual matrices and reshape to 2D tensors.
- auto x_slice = builder->Slice(
- x_flat, {i, 0, 0},
- {i + 1, x_shape.dim_size(ndims - 2), x_shape.dim_size(ndims - 1)},
- {1, 1, 1});
- x_slice = builder->Reshape(
- x_slice, {x_shape.dim_size(ndims - 2), x_shape.dim_size(ndims - 1)});
- auto y_slice = builder->Slice(
- y_flat, {i, 0, 0},
- {i + 1, y_shape.dim_size(ndims - 2), y_shape.dim_size(ndims - 1)},
- {1, 1, 1});
- y_slice = builder->Reshape(
- y_slice, {y_shape.dim_size(ndims - 2), y_shape.dim_size(ndims - 1)});
-
- // Transpose if needed.
- auto lhs = adj_x_ ? builder->Transpose(x_slice, {1, 0}) : x_slice;
- auto rhs = adj_y_ ? builder->Transpose(y_slice, {1, 0}) : y_slice;
-
- // Multiply matrices and add an outer singleton dimension to the output
- // so we can concatenate along the flattened batch dimension later.
- auto out = builder->Dot(lhs, rhs);
- out = builder->Reshape(out,
- {1, dimensions[ndims - 2], dimensions[ndims - 1]});
- out_slices.push_back(out);
- }
-
- // Concatenate output slices and reshape to original number of dimensions.
- xla::ComputationDataHandle data;
- if (out_slices.empty()) {
- // It is illegal to pass an empty list to ConcatInDim.
- // The batch count is empty, so both inputs must have zero elements.
- // Arbitrarily use the left input as the argument to Reshape().
- data = x_handle;
- } else {
- data = builder->ConcatInDim(out_slices, 0);
- }
- data = builder->Reshape(data, dimensions);
-
- ctx->SetOutput(0, data);
+ auto result =
+ BatchDot(ctx->builder(), ctx->Input(0), ctx->Input(1), adj_x_, adj_y_);
+ OP_REQUIRES_OK(ctx, result.status());
+ ctx->SetOutput(0, result.ValueOrDie());
}
private:
diff --git a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc
new file mode 100644
index 0000000000..87d858f763
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc
@@ -0,0 +1,39 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/lib/cholesky.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+
+namespace tensorflow {
+namespace {
+
+class CholeskyOp : public XlaOpKernel {
+ public:
+ explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ void Compile(XlaOpKernelContext* ctx) override {
+ auto result = Cholesky(ctx->builder(), ctx->Input(0));
+ if (!result.ok()) {
+ ctx->SetStatus(result.status());
+ return;
+ }
+ ctx->SetOutput(0, result.ValueOrDie());
+ }
+};
+
+REGISTER_XLA_OP(Name("Cholesky"), CholeskyOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
new file mode 100644
index 0000000000..21ad21f737
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -0,0 +1,120 @@
+# Utilities for building XLA computations.
+
+licenses(["notice"]) # Apache 2.0
+
+package(
+ default_visibility = ["//tensorflow/compiler/tf2xla:friends"],
+)
+
+# Filegroup used to collect source files for dependency checking.
+filegroup(
+ name = "c_srcs",
+ data = glob([
+ "**/*.cc",
+ "**/*.h",
+ ]),
+)
+
+load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
+
+cc_library(
+ name = "batch_dot",
+ srcs = ["batch_dot.cc"],
+ hdrs = ["batch_dot.h"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "cholesky",
+ srcs = ["cholesky.cc"],
+ hdrs = ["cholesky.h"],
+ deps = [
+ ":batch_dot",
+ ":triangular_solve",
+ ":util",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "triangular_solve",
+ srcs = ["triangular_solve.cc"],
+ hdrs = ["triangular_solve.h"],
+ deps = [
+ ":batch_dot",
+ ":util",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/core:lib",
+ ],
+)
+
+xla_test(
+ name = "triangular_solve_test",
+ srcs = ["triangular_solve_test.cc"],
+ deps = [
+ ":triangular_solve",
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+cc_library(
+ name = "util",
+ srcs = ["util.cc"],
+ hdrs = ["util.h"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/core:lib",
+ ],
+)
+
+# -----------------------------------------------------------------------------
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
new file mode 100644
index 0000000000..28a5e6a58b
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
@@ -0,0 +1,154 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+// The current implementation simply unrolls the computation along the batch
+// dimension.
+// TODO(andydavis): add batching support to XLA's Dot operator.
+xla::StatusOr<xla::ComputationDataHandle> BatchDot(
+ xla::ComputationBuilder* builder, xla::ComputationDataHandle x,
+ xla::ComputationDataHandle y, bool transpose_x, bool transpose_y) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> x_shape,
+ builder->GetShape(x));
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> y_shape,
+ builder->GetShape(y));
+
+ // Check that both tensors have the same number of dimensions. There must be
+ // at least two (the batch dimensions can be empty).
+ if (xla::ShapeUtil::Rank(*x_shape) != xla::ShapeUtil::Rank(*y_shape)) {
+ return errors::InvalidArgument(
+ "Arguments to BatchedDot have different ranks: ",
+ xla::ShapeUtil::HumanString(*x_shape), " vs. ",
+ xla::ShapeUtil::HumanString(*y_shape));
+ }
+ const int ndims = xla::ShapeUtil::Rank(*x_shape);
+ if (ndims < 2) {
+ return errors::InvalidArgument(
+ "Arguments to BatchedDot must have rank >= 2: ", ndims);
+ }
+
+ // The batch dimensions must be equal and the matrix dimensions must be
+ // valid.
+ std::vector<int64> dimensions;
+ int64 batch_count = 1;
+ for (int i = 0; i < ndims - 2; ++i) {
+ int64 x_size = x_shape->dimensions(i);
+ int64 y_size = y_shape->dimensions(i);
+ if (x_size != y_size) {
+ return errors::InvalidArgument(
+ "Dimension ", i, " of inputs to BatchedDot must be equal: ",
+ xla::ShapeUtil::HumanString(*x_shape), " vs ",
+ xla::ShapeUtil::HumanString(*y_shape));
+ }
+ dimensions.push_back(x_size);
+ batch_count *= x_size;
+ }
+
+ int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1);
+ int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2);
+ int64 x_inner_dim_size = x_shape->dimensions(x_inner_dim);
+ int64 y_inner_dim_size = y_shape->dimensions(y_inner_dim);
+ if (x_inner_dim_size != y_inner_dim_size) {
+ return errors::InvalidArgument(
+ "Dimensions ", x_inner_dim, " and ", y_inner_dim,
+ " of arguments to BatchedDot must be equal: ",
+ xla::ShapeUtil::HumanString(*x_shape), " transpose: ", transpose_x,
+ " vs. ", xla::ShapeUtil::HumanString(*y_shape),
+ " transpose: ", transpose_y);
+ }
+
+ // If there are no batch dimensions, use a regular Dot. This case exists
+ // to improve the readability of the emitted graphs.
+ if (dimensions.empty()) {
+ auto lhs = transpose_x ? builder->Transpose(x, {1, 0}) : x;
+ auto rhs = transpose_y ? builder->Transpose(y, {1, 0}) : y;
+ return builder->Dot(lhs, rhs);
+ }
+
+ int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2);
+ int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1);
+ dimensions.push_back(x_shape->dimensions(x_outer_dim));
+ dimensions.push_back(y_shape->dimensions(y_outer_dim));
+
+ if (x_shape->element_type() == xla::C64 && transpose_x) {
+ x = builder->Conj(x);
+ }
+ if (y_shape->element_type() == xla::C64 && transpose_y) {
+ y = builder->Conj(y);
+ }
+
+ // Reshape input tensors into 3D tensors by flattening the batch
+ // dimensions. This makes it easier to unroll the batch dimension.
+ auto x_flat =
+ builder->Reshape(x, {batch_count, x_shape->dimensions(ndims - 2),
+ x_shape->dimensions(ndims - 1)});
+ auto y_flat =
+ builder->Reshape(y, {batch_count, y_shape->dimensions(ndims - 2),
+ y_shape->dimensions(ndims - 1)});
+
+ // Slice batches into individual matrices and multiply them.
+ std::vector<xla::ComputationDataHandle> out_slices;
+ for (int64 i = 0; i < batch_count; ++i) {
+ // Slice off individual matrices and reshape to 2D tensors.
+ auto x_slice = builder->Slice(
+ x_flat, {i, 0, 0},
+ {i + 1, x_shape->dimensions(ndims - 2), x_shape->dimensions(ndims - 1)},
+ {1, 1, 1});
+ x_slice = builder->Reshape(x_slice, {x_shape->dimensions(ndims - 2),
+ x_shape->dimensions(ndims - 1)});
+ auto y_slice = builder->Slice(
+ y_flat, {i, 0, 0},
+ {i + 1, y_shape->dimensions(ndims - 2), y_shape->dimensions(ndims - 1)},
+ {1, 1, 1});
+ y_slice = builder->Reshape(y_slice, {y_shape->dimensions(ndims - 2),
+ y_shape->dimensions(ndims - 1)});
+
+ // Transpose if needed.
+ auto lhs = transpose_x ? builder->Transpose(x_slice, {1, 0}) : x_slice;
+ auto rhs = transpose_y ? builder->Transpose(y_slice, {1, 0}) : y_slice;
+
+ // Multiply matrices and add an outer singleton dimension to the output
+ // so we can concatenate along the flattened batch dimension later.
+ auto out = builder->Dot(lhs, rhs);
+ out = builder->Reshape(out,
+ {1, dimensions[ndims - 2], dimensions[ndims - 1]});
+ out_slices.push_back(out);
+ }
+
+ // Concatenate output slices and reshape to original number of dimensions.
+ xla::ComputationDataHandle data;
+ if (out_slices.empty()) {
+ // It is illegal to pass an empty list to ConcatInDim.
+ // The batch count is empty, so both inputs must have zero elements.
+ // Arbitrarily use the left input as the argument to Reshape().
+ data = x;
+ } else {
+ data = builder->ConcatInDim(out_slices, 0);
+ }
+ return builder->Reshape(data, dimensions);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h
new file mode 100644
index 0000000000..b46bc7417d
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h
@@ -0,0 +1,51 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_
+#define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+
+namespace tensorflow {
+
+// Multiplies slices of two tensors in batches.
+
+// Multiplies all slices of `Tensor` `x` and `y` (each slice can be
+// viewed as an element of a batch), and arranges the individual results
+// in a single output tensor of the same batch size. Each of the
+// individual slices can optionally be transposed before multiplication by
+// setting the `transpose_x` or `transpose_y` flag to `true`.
+//
+// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]`
+// and `[..., r_y, c_y]`.
+//
+// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where:
+//
+// r_o = c_x if transpose_x else r_x
+// c_o = r_y if transpose_y else c_y
+//
+// It is computed as:
+//
+// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
+// TODO(phawkins): add an option to take the complex conjugate of the LHS or
+// RHS.
+xla::StatusOr<xla::ComputationDataHandle> BatchDot(
+ xla::ComputationBuilder* builder, xla::ComputationDataHandle x,
+ xla::ComputationDataHandle y, bool transpose_x, bool transpose_y);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc
new file mode 100644
index 0000000000..b3cc489adf
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc
@@ -0,0 +1,166 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/lib/cholesky.h"
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
+#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h"
+#include "tensorflow/compiler/tf2xla/lib/util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+namespace {
+
+// def cholesky_unblocked(a):
+// assert len(a.shape) == 2 and a.shape[-2] == a.shape[-1]
+// n = a.shape[-2]
+// l = np.zeros_like(a)
+// for j in xrange(n):
+// r = l[..., j, :j]
+// l[..., j, j] = np.sqrt(a[..., j, j] - np.dot(r, r))
+// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j],
+// np.transpose(r))) / l[..., j, j]
+// return l
+xla::StatusOr<xla::ComputationDataHandle> CholeskyUnblocked(
+ xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(a));
+ xla::ComputationDataHandle l = Zeros(builder, *shape);
+ const int64 n = xla::ShapeUtil::GetDimension(*shape, -2);
+ for (int j = 0; j < n; ++j) {
+ // Picture of block structure:
+ // ... \
+ // \
+ // -- r -- d
+ // |\
+ // B c \
+ // | \
+ // | ...
+ //
+ // ^
+ // column j
+ TF_ASSIGN_OR_RETURN(auto d,
+ SliceInMinorDims(builder, a, {j, j}, {j + 1, j + 1}));
+ TF_ASSIGN_OR_RETURN(auto c,
+ SliceInMinorDims(builder, a, {j + 1, j}, {n, j + 1}));
+ xla::ComputationDataHandle new_d_squared = d;
+ xla::ComputationDataHandle br;
+ if (j > 0) {
+ TF_ASSIGN_OR_RETURN(auto r,
+ SliceInMinorDims(builder, l, {j, 0}, {j + 1, j}));
+ TF_ASSIGN_OR_RETURN(auto b,
+ SliceInMinorDims(builder, l, {j + 1, 0}, {n, j}));
+ TF_ASSIGN_OR_RETURN(auto r_squared,
+ BatchDot(builder, r, r, /*transpose_x=*/false,
+ /*transpose_y=*/true));
+ new_d_squared = builder->Sub(new_d_squared, r_squared);
+
+ TF_ASSIGN_OR_RETURN(br, BatchDot(builder, b, r, /*transpose_x=*/false,
+ /*transpose_y=*/true));
+ }
+ auto new_d_inv = builder->Pow(
+ new_d_squared, FloatLiteral(builder, shape->element_type(), -0.5));
+ auto new_d = builder->Mul(new_d_inv, new_d_squared);
+ TF_ASSIGN_OR_RETURN(l, UpdateSliceInMinorDims(builder, l, new_d, {j, j}));
+
+ if (j > 0) {
+ c = builder->Sub(c, br);
+ }
+ auto new_c = builder->Mul(c, new_d_inv);
+ TF_ASSIGN_OR_RETURN(l,
+ UpdateSliceInMinorDims(builder, l, new_c, {j + 1, j}));
+ }
+ return l;
+}
+
+} // namespace
+
+xla::StatusOr<xla::ComputationDataHandle> Cholesky(
+ xla::ComputationBuilder* builder, xla::ComputationDataHandle a,
+ int64 block_size) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
+ builder->GetShape(a));
+ const int ndims = xla::ShapeUtil::Rank(*a_shape);
+ if (ndims < 2) {
+ return errors::InvalidArgument(
+ "Arguments to Cholesky must have rank >= 2: ", ndims);
+ }
+
+ const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1);
+ if (n != xla::ShapeUtil::GetDimension(*a_shape, -2)) {
+ return errors::InvalidArgument(
+ "Arguments to Cholesky must be square matrices: ",
+ xla::ShapeUtil::HumanString(*a_shape));
+ }
+
+ if (block_size < 1) {
+ return errors::InvalidArgument(
+ "block_size argument to Cholesky must be >= 1; got ", block_size);
+ }
+
+ // Blocked left-looking Cholesky factorization.
+ // Algorithm 1 from
+ // Haidar, Azzam, et al. "High-performance Cholesky factorization for GPU-only
+ // execution." Proceedings of General Purpose GPUs. ACM, 2017.
+ xla::ComputationDataHandle l = Zeros(builder, *a_shape);
+ for (int64 i = 0; i < n; i += block_size) {
+ int64 k = std::min(block_size, n - i);
+ if (i > 0) {
+ // TODO(phawkins): consider implementing SYRK for the diagonal part of
+ // the panel.
+ // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i]))
+ TF_ASSIGN_OR_RETURN(auto lhs,
+ SliceInMinorDims(builder, l, {i, 0}, {n, i}));
+ TF_ASSIGN_OR_RETURN(auto rhs,
+ SliceInMinorDims(builder, l, {i, 0}, {i + k, i}));
+ TF_ASSIGN_OR_RETURN(auto delta,
+ BatchDot(builder, lhs, rhs, /*transpose_x=*/false,
+ /*transpose_y=*/true));
+ TF_ASSIGN_OR_RETURN(auto before,
+ SliceInMinorDims(builder, a, {i, i}, {n, i + k}));
+ TF_ASSIGN_OR_RETURN(
+ a, UpdateSliceInMinorDims(builder, a, builder->Sub(before, delta),
+ {i, i}));
+ }
+
+ // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k])
+ TF_ASSIGN_OR_RETURN(auto x,
+ SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
+ TF_ASSIGN_OR_RETURN(auto factorized, CholeskyUnblocked(builder, x));
+ TF_ASSIGN_OR_RETURN(l,
+ UpdateSliceInMinorDims(builder, l, factorized, {i, i}));
+
+ if (i + k < n) {
+ // l[i+k:, i:i+k] = trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k])
+ TF_ASSIGN_OR_RETURN(auto panel,
+ SliceInMinorDims(builder, a, {i + k, i}, {n, i + k}));
+ TF_ASSIGN_OR_RETURN(auto update,
+ TriangularSolve(builder, factorized, panel,
+ /*block_size=*/8));
+ TF_ASSIGN_OR_RETURN(
+ l, UpdateSliceInMinorDims(builder, l, update, {i + k, i}));
+ }
+ }
+ return l;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h
new file mode 100644
index 0000000000..2bead7359b
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.h
@@ -0,0 +1,38 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
+#define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+
+namespace tensorflow {
+
+// Computes the Cholesky decompositions of a batch of symmetric positive
+// definite matrices.
+// `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the
+// two minor dimensions equal.
+// The algorithm implements a blocked Cholesky decomposition; `block_size` is
+// the block size to use.
+// TODO(phawkins): check for negative values on the diagonal and return an
+// error, instead of silently yielding NaNs.
+xla::StatusOr<xla::ComputationDataHandle> Cholesky(
+ xla::ComputationBuilder* builder, xla::ComputationDataHandle a,
+ int64 block_size = 256);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
new file mode 100644
index 0000000000..579944c3a3
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
@@ -0,0 +1,175 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h"
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
+#include "tensorflow/compiler/tf2xla/lib/util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
+ xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
+ xla::ComputationDataHandle b, int64 block_size) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
+ builder->GetShape(a));
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> b_shape,
+ builder->GetShape(b));
+ if (xla::ShapeUtil::Rank(*a_shape) != xla::ShapeUtil::Rank(*b_shape)) {
+ return errors::InvalidArgument(
+ "Arguments to TriangularSolve have different ranks: ",
+ xla::ShapeUtil::HumanString(*a_shape), " vs. ",
+ xla::ShapeUtil::HumanString(*b_shape));
+ }
+ const int ndims = xla::ShapeUtil::Rank(*a_shape);
+ if (ndims < 2) {
+ return errors::InvalidArgument(
+ "Arguments to TriangularSolve must have rank >= 2: ", ndims);
+ }
+ // The batch dimensions must be equal.
+ std::vector<int64> batch_dimensions;
+ for (int i = 0; i < ndims - 2; ++i) {
+ int64 a_size = a_shape->dimensions(i);
+ int64 b_size = b_shape->dimensions(i);
+ if (a_size != b_size) {
+ return errors::InvalidArgument(
+ "Batch dimensions of arguments to TriangularSolve must be equal: ",
+ xla::ShapeUtil::HumanString(*a_shape), " vs ",
+ xla::ShapeUtil::HumanString(*b_shape));
+ }
+ batch_dimensions.push_back(a_size);
+ }
+
+ const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1);
+ const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2);
+ if (n != xla::ShapeUtil::GetDimension(*a_shape, -2)) {
+ return errors::InvalidArgument(
+ "The 'a' arguments to TriangularSolve must be square matrices: ",
+ xla::ShapeUtil::HumanString(*a_shape));
+ }
+ if (n != xla::ShapeUtil::GetDimension(*b_shape, -1)) {
+ return errors::InvalidArgument(
+ "Arguments to TriangularSolve have incompatible matrix shapes: ",
+ xla::ShapeUtil::HumanString(*a_shape), " vs ",
+ xla::ShapeUtil::HumanString(*b_shape));
+ }
+
+ if (block_size < 1) {
+ return errors::InvalidArgument(
+ "block_size argument to TriangularSolve must be >= 1; got ",
+ block_size);
+ }
+
+ // Returns [b1, b2, ... , bn, indices[0], indices[1]].
+ auto prepend_batch_dims = [&](std::array<int64, 2> indices) {
+ std::vector<int64> output(ndims);
+ std::copy(batch_dimensions.begin(), batch_dimensions.end(), output.begin());
+ std::copy(indices.begin(), indices.end(),
+ output.begin() + batch_dimensions.size());
+ return output;
+ };
+
+ std::map<int, xla::Computation> base_computations;
+ auto get_base_triangular_solve =
+ [&](int k) -> xla::StatusOr<xla::Computation*> {
+ xla::Computation& computation = base_computations[k];
+ if (computation.IsNull()) {
+ std::unique_ptr<xla::ComputationBuilder> sub = builder->CreateSubBuilder(
+ tensorflow::strings::StrCat("trsm_base_", k));
+
+ auto a_param =
+ sub->Parameter(0,
+ xla::ShapeUtil::MakeShape(b_shape->element_type(),
+ prepend_batch_dims({k, k})),
+ "a");
+
+ auto b_param =
+ sub->Parameter(1,
+ xla::ShapeUtil::MakeShape(b_shape->element_type(),
+ prepend_batch_dims({m, k})),
+ "b");
+
+ // TODO(phawkins): it might make sense to use a while loop here, rather
+ // than unrolling.
+ // TODO(phawkins): the left-looking variant of the algorithm might be more
+ // efficient at block size 1.
+ TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param,
+ /*block_size=*/1)
+ .status());
+
+ TF_ASSIGN_OR_RETURN(computation, sub->Build());
+ }
+ return &computation;
+ };
+
+ xla::ComputationDataHandle output = Zeros(builder, *b_shape);
+
+ // Right-looking blocked triangular solve.
+ // For an explanation of the algorithm, see the TRSM discussion in:
+ // Goto, Kazushige, and Robert Van De Geijn. "High-performance implementation
+ // of the level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1
+ // (2008): 4.
+ for (int64 i = 0; i < n; i += block_size) {
+ int64 k = std::min(block_size, n - i);
+
+ // if k > 1:
+ // output[..., :, i:i+k] = triangular_solve(
+ // a[..., i:i+k, ..., i:i+k], b[..., :, i:i+k], side='Right',
+ // kind='Lower', transpose=True, block_size=1)
+ // else:
+ // output[..., :, i] = b[..., :, i] / a[..., i, i]
+ TF_ASSIGN_OR_RETURN(auto a_slice,
+ SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
+ TF_ASSIGN_OR_RETURN(auto b_slice,
+ SliceInMinorDims(builder, b, {0, i}, {m, i + k}));
+ xla::ComputationDataHandle update;
+ if (k > 1) {
+ TF_ASSIGN_OR_RETURN(xla::Computation * solve,
+ get_base_triangular_solve(k));
+ update = builder->Call(*solve, {a_slice, b_slice});
+ } else {
+ update = builder->Div(b_slice, a_slice);
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ output, UpdateSliceInMinorDims(builder, output, update, {0, i}));
+ // b[..., :, i+k:] -= np.dot(output[..., :, i:i+k],
+ // np.transpose(..., a[i+k:, i:i+k]))
+ if (i + k < n) {
+ TF_ASSIGN_OR_RETURN(auto a_slice_2,
+ SliceInMinorDims(builder, a, {i + k, i}, {n, i + k}));
+ TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, update, a_slice_2,
+ /*transpose_x=*/false,
+ /*transpose_y=*/true));
+
+ TF_ASSIGN_OR_RETURN(auto b_slice_2,
+ SliceInMinorDims(builder, b, {0, i + k}, {m, n}));
+ b_update = builder->Sub(b_slice_2, b_update);
+ TF_ASSIGN_OR_RETURN(
+ b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k}));
+ }
+ }
+ return output;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h
new file mode 100644
index 0000000000..501d026411
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_
+#define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+
+namespace tensorflow {
+
+// Solves systems of linear equations with upper or lower triangular matrices by
+// backsubstitution.
+//
+// `a` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form
+// square matrices. The strictly upper triangular part of each inner-most matrix
+// is assumed to be zero and not accessed.
+// `b` is a tensor of shape `[..., M, K]`.
+//
+// The innermost matrices in the output satisfy matrix equations
+// `output[..., i, j] * adjoint(a[..., k, j]) = b[..., i, k]`.
+//
+// Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no
+// blocking is used.
+// TODO(phawkins): equivalent to the BLAS TRSM routine with side=right,
+// kind=lower, and transposed_a=true. Implement the other possible combinations
+// of side, kind and transposed_a.
+xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
+ xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
+ xla::ComputationDataHandle b, int64 block_size = 256);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
new file mode 100644
index 0000000000..671d9aa4fe
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
@@ -0,0 +1,69 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h"
+
+#include <memory>
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+namespace {
+
+using TriangularSolveTest = xla::ClientLibraryTestBase;
+
+XLA_TEST_F(TriangularSolveTest, Simple) {
+ xla::ComputationBuilder builder(client_, TestName());
+
+ xla::Array2D<float> a_vals({
+ {2, 0, 0, 0},
+ {3, 6, 0, 0},
+ {4, 7, 9, 0},
+ {5, 8, 10, 11},
+ });
+ xla::Array2D<float> b_vals({
+ {1, 2, 3, 4},
+ {5, 6, 7, 8},
+ {9, 10, 11, 12},
+ });
+
+ xla::ComputationDataHandle a, b;
+ auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
+ auto b_data = CreateR2Parameter<float>(b_vals, 1, "b", &builder, &b);
+ auto result = TriangularSolve(&builder, a, b, /*block_size=*/2);
+ TF_ASSERT_OK(result.status());
+
+ xla::Array2D<float> expected({
+ {0.5, 0.08333334, 0.04629629, 0.03367003},
+ {2.5, -0.25, -0.1388889, -0.1010101},
+ {4.5, -0.58333331, -0.32407406, -0.23569024},
+ });
+
+ ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
+ xla::ErrorSpec(2e-3, 2e-3));
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc
new file mode 100644
index 0000000000..7ffe0aa6df
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/util.cc
@@ -0,0 +1,107 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/lib/util.h"
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder,
+ xla::Shape& shape) {
+ return builder->Broadcast(
+ builder->ConstantLiteral(xla::Literal::Zero(shape.element_type())),
+ xla::AsInt64Slice(shape.dimensions()));
+}
+
+xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder,
+ xla::PrimitiveType type, double value) {
+ switch (type) {
+ case xla::F16:
+ return builder->ConstantR0<xla::half>(static_cast<xla::half>(value));
+ break;
+ case xla::F32:
+ return builder->ConstantR0<float>(static_cast<float>(value));
+ break;
+ case xla::F64:
+ return builder->ConstantR0<double>(value);
+ break;
+ case xla::C64:
+ return builder->ConstantR0<xla::complex64>(value);
+ break;
+ default:
+ LOG(FATAL) << "unhandled element type " << type;
+ }
+}
+
+xla::StatusOr<xla::ComputationDataHandle> SliceInMinorDims(
+ xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
+ gtl::ArraySlice<int64> start, gtl::ArraySlice<int64> end) {
+ TF_RET_CHECK(start.size() == end.size());
+ int64 n_minor_dims = start.size();
+
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
+
+ const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+ TF_RET_CHECK(n_minor_dims <= n_dims);
+ gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape->dimensions()),
+ /*pos=*/0,
+ /*len=*/n_dims - n_minor_dims);
+
+ // Prepends 0s in the major dim
+ std::vector<int64> padded_start(n_dims, 0);
+ std::copy(start.begin(), start.end(),
+ padded_start.begin() + major_dims.size());
+
+ // Prepends the shape of the major dims.
+ std::vector<int64> padded_end(n_dims);
+ std::copy(major_dims.begin(), major_dims.end(), padded_end.begin());
+ std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size());
+
+ std::vector<int64> strides(n_dims, 1);
+ return builder->Slice(x, padded_start, padded_end, strides);
+}
+
+xla::StatusOr<xla::ComputationDataHandle> UpdateSlice(
+ xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
+ const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start) {
+ // TODO(phawkins): make int64 work on all backends, remove the int32 cast.
+ std::vector<int32> start_as_int32(start.begin(), start.end());
+ return builder->DynamicUpdateSlice(
+ x, update, builder->ConstantR1<int32>(start_as_int32));
+}
+
+xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
+ xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
+ const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+ const int64 n_minor_dims = start.size();
+ TF_RET_CHECK(n_minor_dims <= n_dims);
+ std::vector<int64> padded_start(n_dims, 0);
+ std::copy(start.begin(), start.end(),
+ padded_start.begin() + (n_dims - n_minor_dims));
+ return UpdateSlice(builder, x, update, padded_start);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h
new file mode 100644
index 0000000000..8fba6b5cf2
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/util.h
@@ -0,0 +1,54 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_
+#define TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace tensorflow {
+
+// Returns a zero-filled tensor with shape `shape`.
+xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder,
+ xla::Shape& shape);
+
+// Returns a floating point scalar constant of 'type' with 'value'.
+// If 'type' is complex, returns a real value with zero imaginary component.
+xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder,
+ xla::PrimitiveType type, double value);
+
+// Performs a slice in the minor dimensions of a Tensor.
+xla::StatusOr<xla::ComputationDataHandle> SliceInMinorDims(
+ xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
+ gtl::ArraySlice<int64> start, gtl::ArraySlice<int64> end);
+
+// Updates a slice of 'x', i.e.,
+// x[start[0], ..., start[n]] = update
+xla::StatusOr<xla::ComputationDataHandle> UpdateSlice(
+ xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
+ const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start);
+
+// Updates a slice of 'x', where 'start' contains a list of minor dimensions:
+// x[..., start[0], ..., start[n]] = update
+xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
+ xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
+ const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_
diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc
index 1efbe0ffb1..c969212a1b 100644
--- a/tensorflow/compiler/tf2xla/type_util.cc
+++ b/tensorflow/compiler/tf2xla/type_util.cc
@@ -49,6 +49,9 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) {
case tensorflow::DT_UINT64:
*type = xla::U64;
return Status::OK();
+ case tensorflow::DT_BFLOAT16:
+ *type = xla::BF16;
+ return Status::OK();
case tensorflow::DT_HALF:
*type = xla::F16;
return Status::OK();
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 4d40ca5825..ac7d4cfb12 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -236,12 +236,6 @@ class XlaCompiler {
// to the computation.
bool allow_cpu_custom_calls = false;
- // If 'local_executable_has_hybrid_result', the top-level pointers of the
- // result tuple of compiled programs are stored in host memory and the
- // nested buffers in device memory, otherwise the whole result tuple is
- // stored in device memory.
- bool local_executable_has_hybrid_result = false;
-
// If not nullptr, populate_resource_manager is called with the
// compilation device's resource manager when the compilation
// device is created, and can be used to create metadata objects
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index 1df6173275..9c3e15d2fa 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -16,6 +16,7 @@ limitations under the License.
// This file defines helper routines for Tla JIT compilation.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
@@ -185,25 +186,9 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b,
DataType data_type,
double value) {
- xla::Literal literal;
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
- switch (type) {
- case xla::F16:
- return b->ConstantR0<xla::half>(static_cast<xla::half>(value));
- break;
- case xla::F32:
- return b->ConstantR0<float>(static_cast<float>(value));
- break;
- case xla::F64:
- return b->ConstantR0<double>(value);
- break;
- case xla::C64:
- return b->ConstantR0<complex64>(value);
- break;
- default:
- LOG(FATAL) << "unhandled element type " << type;
- }
+ return ::tensorflow::FloatLiteral(b, type, value);
}
/* static */ Status XlaHelpers::ReshapeLiteral(
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index 660f419e46..515b572b0e 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -77,6 +77,7 @@ cc_library(
hdrs = ["types.h"],
visibility = [":friends"],
deps = [
+ "//tensorflow/core:framework_lite",
"//tensorflow/core:lib",
"//third_party/eigen3",
],
@@ -339,6 +340,7 @@ cc_library(
name = "array",
hdrs = ["array.h"],
deps = [
+ ":status",
":types",
"//tensorflow/core:lib",
],
diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h
index ba898d1f4e..213e0bac6c 100644
--- a/tensorflow/compiler/xla/array.h
+++ b/tensorflow/compiler/xla/array.h
@@ -23,8 +23,10 @@ limitations under the License.
#include <iterator>
#include <memory>
#include <random>
+#include <type_traits>
#include <vector>
+#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -35,10 +37,63 @@ limitations under the License.
namespace xla {
+namespace array_impl {
+
+// conjunction
+//
+// Performs a compile-time logical AND operation on the passed types (which
+// must have `::value` members convertible to `bool`. Short-circuits if it
+// encounters any `false` members (and does not compare the `::value` members
+// of any remaining arguments).
+//
+// This metafunction is designed to be a drop-in replacement for the C++17
+// `std::conjunction` metafunction.
+template <typename... Ts>
+struct conjunction;
+
+template <typename T, typename... Ts>
+struct conjunction<T, Ts...>
+ : std::conditional<T::value, conjunction<Ts...>, T>::type {};
+
+template <>
+struct conjunction<> : std::true_type {};
+
+// A type trait that is valid when all elements in a parameter pack are of
+// integral type.
+template <typename... T>
+using pack_is_integral = conjunction<std::is_integral<T>...>;
+
+// Compares three same-sized vectors elementwise. For each item in `values`,
+// returns false if any of values[i] is outside the half-open range [starts[i],
+// ends[i]).
+template <typename C1, typename C2, typename C3>
+bool all_inside_range(const C1& values, const C2& range_starts,
+ const C3& range_ends) {
+ for (size_t i = 0, e = values.size(); i < e; ++i) {
+ if (values[i] < range_starts[i] || values[i] >= range_ends[i]) {
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace array_impl
+
// General N dimensional array class with arbitrary value type.
template <typename T>
class Array {
public:
+ // Type inference can have a hard time parsing very deep initializer list
+ // nests, especially if one or more dimensions is one as the compiler just
+ // sees a single-element integer initializer. These typedefs allow casting
+ // explicitly with less typing.
+ using InitializerList1D = std::initializer_list<T>;
+ using InitializerList2D = std::initializer_list<InitializerList1D>;
+ using InitializerList3D = std::initializer_list<InitializerList2D>;
+ using InitializerList4D = std::initializer_list<InitializerList3D>;
+
+ using value_type = T;
+
// Creates a new array with the specified dimensions.
explicit Array(tensorflow::gtl::ArraySlice<int64> sizes)
: Array(sizes, T()) {}
@@ -53,7 +108,7 @@ class Array {
// Creates a 2D array from the given nested initializer list. The outer
// initializer list is the first dimension, the inner is the second dimension.
// For example, {{1, 2, 3}, {4, 5, 6}} results in an array with n1=2 and n2=3.
- Array(std::initializer_list<std::initializer_list<T>> values)
+ Array(InitializerList2D values)
: Array(ToInt64Vector({values.size(), values.begin()->size()})) {
int64 idx = 0;
for (const auto& it1 : values) {
@@ -67,8 +122,7 @@ class Array {
// Creates a 3D array from the given nested initializer list. The outer
// initializer list is the first dimension, and so on.
- Array(std::initializer_list<std::initializer_list<std::initializer_list<T>>>
- values)
+ Array(InitializerList3D values)
: Array(ToInt64Vector({values.size(), values.begin()->size(),
values.begin()->begin()->size()})) {
int64 idx = 0;
@@ -85,9 +139,7 @@ class Array {
// Creates a 4D array from the given nested initializer list. The outer
// initializer list is the first dimension, and so on.
- Array(std::initializer_list<
- std::initializer_list<std::initializer_list<std::initializer_list<T>>>>
- values)
+ Array(InitializerList4D values)
: Array(ToInt64Vector({values.size(), values.begin()->size(),
values.begin()->begin()->size(),
values.begin()->begin()->begin()->size()})) {
@@ -173,10 +225,46 @@ class Array {
}
}
+ // Invokes a callback with the (indices, value_ptr) for each cell in the
+ // array. If a callback returns a non-OK status, returns that else returns
+ // Status::OK().
+ Status EachStatus(
+ std::function<Status(tensorflow::gtl::ArraySlice<int64>, T*)> f) {
+ std::vector<int64> index(sizes_.size());
+ for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
+ Status s = f(index, &values_[i]);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+ return Status::OK();
+ }
+
+ // Invokes a callback with the (indices, value) for each cell in the array.
+ // If a callback returns a non-OK status, returns that else returns
+ // Status::OK().
+ Status EachStatus(
+ std::function<Status(tensorflow::gtl::ArraySlice<int64>, T)> f) const {
+ std::vector<int64> index(sizes_.size());
+ for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
+ Status s = f(index, values_[i]);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+ return Status::OK();
+ }
+
// Returns the value at the cell specified by the indexes. The number of
// arguments have to match with the number of dimensions for the array.
+ //
+ // The type trait is required to avoid this overload participating too
+ // eagerly; a parameter pack can take zero or more elements, so we must
+ // restrict this to only parameter packs that are all of integral type.
template <typename... Dims>
- const T& operator()(Dims... dims) const {
+ typename std::enable_if<array_impl::pack_is_integral<Dims...>::value,
+ const T&>::type
+ operator()(Dims... dims) const {
// We are using a std::array to avoid having to allocate memory in this
// function for performance reasons.
std::array<int64, sizeof...(dims)> indexes{{static_cast<int64>(dims)...}};
@@ -186,7 +274,9 @@ class Array {
// Returns the value at the cell specified by the indexes. The number of
// arguments have to match with the number of dimensions for the array.
template <typename... Dims>
- T& operator()(Dims... dims) {
+ typename std::enable_if<array_impl::pack_is_integral<Dims...>::value,
+ T&>::type
+ operator()(Dims... dims) {
// We are using a std::array to avoid having to allocate memory in this
// function for performance reasons.
std::array<int64, sizeof...(dims)> indexes{{static_cast<int64>(dims)...}};
@@ -255,6 +345,59 @@ class Array {
bool operator!=(const Array<T>& other) const { return !(*this == other); }
+ // Performs the equivalent of a slice operation on this array.
+ Array<T> Slice(tensorflow::gtl::ArraySlice<int64> starts,
+ tensorflow::gtl::ArraySlice<int64> limits) const {
+ CHECK_EQ(starts.size(), num_dimensions());
+ CHECK_EQ(limits.size(), num_dimensions());
+
+ std::vector<int64> sizes;
+ std::transform(starts.begin(), starts.end(), limits.begin(),
+ std::back_inserter(sizes),
+ [](int64 start, int64 limit) { return limit - start; });
+ Array<T> result(sizes);
+
+ std::vector<int64> index(sizes_.size());
+ int64 slice_i = 0;
+ for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
+ if (array_impl::all_inside_range(index, starts, limits)) {
+ // Even though the bounds of result are different to our bounds, we're
+ // iterating in the same order. So we can simply write successive linear
+ // indices instead of recalculating a multi-dimensional index.
+ result.values_[slice_i++] = values_[i];
+ }
+ }
+ return result;
+ }
+
+ // Performs the equivalent of a DynamicUpdateSlice in-place on this array.
+ void UpdateSlice(const Array<T>& from,
+ tensorflow::gtl::ArraySlice<int64> start_indices) {
+ CHECK_EQ(from.num_dimensions(), num_dimensions());
+ std::vector<int64> limit_indices;
+ std::transform(start_indices.begin(), start_indices.end(),
+ from.dimensions().begin(), std::back_inserter(limit_indices),
+ std::plus<int64>{});
+ std::vector<int64> index(sizes_.size());
+ int64 from_i = 0;
+ for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
+ if (array_impl::all_inside_range(index, start_indices, limit_indices)) {
+ // Even though the bounds of from are different to our bounds, we're
+ // iterating in the same order. So we can simply write successive linear
+ // indices instead of recalculating a multi-dimensional index.
+ values_[i] = from.values_[from_i++];
+ }
+ }
+ }
+
+ // Performs an in-place reshape, modifying the dimensions but not the
+ // underlying data.
+ void Reshape(tensorflow::gtl::ArraySlice<int64> new_dimensions) {
+ int64 old_num_elements = num_elements();
+ sizes_ = std::vector<int64>(new_dimensions.begin(), new_dimensions.end());
+ CHECK_EQ(num_elements(), old_num_elements);
+ }
+
// Returns a string representation of the array suitable for debugging.
string ToString() const {
std::vector<string> pieces;
diff --git a/tensorflow/compiler/xla/array_test.cc b/tensorflow/compiler/xla/array_test.cc
index 093784f541..8b94194774 100644
--- a/tensorflow/compiler/xla/array_test.cc
+++ b/tensorflow/compiler/xla/array_test.cc
@@ -71,6 +71,19 @@ TEST(ArrayTest, IndexingReadWrite) {
EXPECT_EQ(arr(1, 2), 61);
}
+TEST(ArrayTest, DynamicIndexingReadWrite) {
+ Array<int> arr({2, 3});
+
+ std::vector<int64> index1 = {1, 1};
+ std::vector<int64> index2 = {1, 2};
+ EXPECT_EQ(arr(index1), 0);
+ EXPECT_EQ(arr(index2), 0);
+ arr(index1) = 51;
+ arr(index2) = 61;
+ EXPECT_EQ(arr(1, 1), 51);
+ EXPECT_EQ(arr(1, 2), 61);
+}
+
TEST(ArrayTest, IndexingReadWriteBool) {
Array<bool> arr{{false, true, false}, {false, true, false}};
@@ -141,5 +154,37 @@ TEST(ArrayTest, Each) {
EXPECT_EQ(arr.num_elements() * (arr.num_elements() - 1) / 2, each_sum);
}
+TEST(ArrayTest, Slice) {
+ Array<int64> arr({2, 4});
+ arr.FillWithMultiples(1);
+
+ Array<int64> identity_slice = arr.Slice({0, 0}, {2, 4});
+ EXPECT_EQ(identity_slice.dimensions(), arr.dimensions());
+ for (auto it1 = arr.begin(), it2 = identity_slice.begin(), e = arr.end();
+ it1 != e; ++it1, ++it2) {
+ EXPECT_EQ(*it1, *it2);
+ }
+
+ Array<int64> sub_slice = arr.Slice({1, 0}, {2, 2});
+ EXPECT_EQ(sub_slice.dimensions(), (std::vector<int64>{1, 2}));
+ const string expected = R"([[4, 5]])";
+ EXPECT_EQ(expected, sub_slice.ToString());
+}
+
+TEST(ArrayTest, UpdateSlice) {
+ Array<int64> arr({3, 4});
+ arr.FillWithMultiples(1);
+
+ Array<int64> sub_arr({2, 2});
+ sub_arr.FillWithMultiples(3);
+
+ arr.UpdateSlice(sub_arr, {1, 1});
+
+ const string expected = R"([[0, 1, 2, 3],
+ [4, 0, 3, 7],
+ [8, 6, 9, 11]])";
+ EXPECT_EQ(expected, arr.ToString());
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index 92cd8e729d..66937d64af 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -142,8 +142,7 @@ StatusOr<std::unique_ptr<Literal>> Client::TransferFromOutfeed(
"TransferToClient request");
}
- Literal literal(response.literal());
- return MakeUnique<Literal>(literal);
+ return MakeUnique<Literal>(response.literal());
}
Status Client::ResetDevice() {
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index 8e1b4be1f3..4c6e320557 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -68,6 +68,7 @@ class ShardingBuilder {
const TileAssignment& tile_assignment) {
OpSharding result;
result.set_type(OpSharding::Type::OpSharding_Type_OTHER);
+ *result.mutable_tile_shape() = tile_shape;
for (int64 dim : tile_assignment.dimensions()) {
result.add_tile_assignment_dimensions(dim);
}
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index ee34682087..fca2bf2688 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -44,6 +44,7 @@ cc_library(
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc
index e6645e4941..d936bd870b 100644
--- a/tensorflow/compiler/xla/client/lib/testing.cc
+++ b/tensorflow/compiler/xla/client/lib/testing.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -48,62 +49,6 @@ std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
} // namespace
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
- if (ShapeUtil::IsTuple(shape)) {
- std::vector<std::unique_ptr<Literal>> elements;
- for (const Shape& element_shape : shape.tuple_shapes()) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element,
- MakeFakeLiteral(element_shape));
- elements.push_back(std::move(element));
- }
- return Literal::MakeTupleOwned(std::move(elements));
- }
- std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
- std::minstd_rand0 engine;
- switch (shape.element_type()) {
- case F32: {
- std::uniform_real_distribution<float> generator(0.0f, 1.0f);
- TF_CHECK_OK(literal->Populate<float>(
- [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
- return generator(engine);
- }));
- break;
- }
- case S32: {
- std::uniform_int_distribution<int32> generator(
- std::numeric_limits<int32>::lowest(),
- std::numeric_limits<int32>::max());
- TF_CHECK_OK(literal->Populate<int32>(
- [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
- return generator(engine);
- }));
- break;
- }
- case S64: {
- std::uniform_int_distribution<int64> generator(
- std::numeric_limits<int64>::lowest(),
- std::numeric_limits<int64>::max());
- TF_CHECK_OK(literal->Populate<int64>(
- [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
- return generator(engine);
- }));
- break;
- }
- case PRED: {
- std::uniform_int_distribution<int> generator(0, 1);
- TF_CHECK_OK(literal->Populate<bool>(
- [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
- return generator(engine);
- }));
- break;
- }
- default:
- return Unimplemented("Unsupported type for fake literal generation: %s",
- ShapeUtil::HumanString(shape).c_str());
- }
- return std::move(literal);
-}
-
std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
Client* client) {
if (ShapeUtil::ByteSizeOf(shape) < (1LL << 30)) {
diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h
index b5c4393dcc..7e640d1307 100644
--- a/tensorflow/compiler/xla/client/lib/testing.h
+++ b/tensorflow/compiler/xla/client/lib/testing.h
@@ -26,10 +26,6 @@ limitations under the License.
namespace xla {
-// Generates fake data in a literal of the given shape, or returns an error
-// status if the element type is currently unhandled for fake data generation.
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape);
-
// Generates fake data of the given shape on the device or dies. The fake data
// is created by performing a computation on the device rather than transferring
// data from the host to the device.
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index 15c744ecd3..b50425a09c 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -27,16 +27,6 @@ namespace se = ::perftools::gputools;
namespace xla {
-ExecutableBuildOptions& ExecutableBuildOptions::set_platform(
- perftools::gputools::Platform* platform) {
- platform_ = platform;
- return *this;
-}
-
-perftools::gputools::Platform* ExecutableBuildOptions::platform() const {
- return platform_;
-}
-
ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal(
int device_ordinal) {
device_ordinal_ = device_ordinal;
@@ -56,16 +46,6 @@ const Shape* ExecutableBuildOptions::result_layout() const {
return result_layout_set_ ? &result_layout_ : nullptr;
}
-ExecutableBuildOptions& ExecutableBuildOptions::set_has_hybrid_result(
- bool has_hybrid_result) {
- has_hybrid_result_ = has_hybrid_result;
- return *this;
-}
-
-bool ExecutableBuildOptions::has_hybrid_result() const {
- return has_hybrid_result_;
-}
-
namespace {
StatusOr<Backend::StreamPtr> BorrowStreamForDevice(int device_ordinal,
Backend* backend) {
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index 9f985ed527..e9eeaa0aa2 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -37,14 +37,6 @@ namespace xla {
// LocalClient::Compile.
class ExecutableBuildOptions {
public:
- // If set, this is the platform to build the computation for. This must match
- // the underlying platform of the service. A value of nullptr indicates the
- // option has not been set.
- //
- // TODO(b/28616830): Support multiple platforms.
- ExecutableBuildOptions& set_platform(perftools::gputools::Platform* platform);
- perftools::gputools::Platform* platform() const;
-
// If set, this is the device to build the computation for. Valid
// device_ordinal values are: 0 to # of devices - 1. These values are
// identical to the device ordinal values used by StreamExecutor. The built
@@ -61,18 +53,10 @@ class ExecutableBuildOptions {
ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout);
const Shape* result_layout() const;
- // If set, the executable will be built to output a hybrid
- // ShapedBuffer with top-level tuple pointers in host memory and
- // result buffers in device memory.
- ExecutableBuildOptions& set_has_hybrid_result(bool has_hybrid_result);
- bool has_hybrid_result() const;
-
private:
- perftools::gputools::Platform* platform_ = nullptr;
int device_ordinal_ = -1;
Shape result_layout_;
bool result_layout_set_ = false;
- bool has_hybrid_result_ = true;
};
class LocalExecutable {
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index fda791401d..93d3cd425f 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -33,6 +33,20 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
+namespace {
+using tensorflow::int64;
+
+constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
+
+// Converts between little and big endian, assuming elements in the array are 16
+// bits long.
+void ConvertEndianShort(char* bytes, int64 size) {
+ CHECK_EQ(size / 2, 0);
+ for (int64 i = 0; i < size; i += 2) {
+ std::swap(bytes[i], bytes[i + 1]);
+ }
+}
+} // namespace
namespace xla {
@@ -169,6 +183,8 @@ Status Literal::Copy(const Literal& src_literal,
return CopyRange<int64>(src_literal, src_base, dest_base, copy_size);
case F16:
return CopyRange<half>(src_literal, src_base, dest_base, copy_size);
+ case BF16:
+ return CopyRange<bfloat16>(src_literal, src_base, dest_base, copy_size);
case F32:
return CopyRange<float>(src_literal, src_base, dest_base, copy_size);
case F64:
@@ -200,6 +216,8 @@ Status Literal::Copy(const Literal& src_literal,
return *Literal::CreateR0<int64>(0);
case F16:
return *Literal::CreateR0<half>(static_cast<half>(0.0f));
+ case BF16:
+ return *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f));
case F32:
return *Literal::CreateR0<float>(0);
case F64:
@@ -285,6 +303,9 @@ Status Literal::Copy(const Literal& src_literal,
case F16:
return *Literal::CreateR0<half>(
static_cast<half>(-std::numeric_limits<float>::infinity()));
+ case BF16:
+ return *Literal::CreateR0<bfloat16>(
+ static_cast<bfloat16>(-std::numeric_limits<float>::infinity()));
case TUPLE:
LOG(FATAL) << "tuple element type has no minimum value";
case OPAQUE:
@@ -321,6 +342,9 @@ Status Literal::Copy(const Literal& src_literal,
case F16:
return *Literal::CreateR0<half>(
static_cast<half>(std::numeric_limits<float>::infinity()));
+ case BF16:
+ return *Literal::CreateR0<bfloat16>(
+ static_cast<bfloat16>(std::numeric_limits<float>::infinity()));
case TUPLE:
LOG(FATAL) << "tuple element type has no maximum value";
case OPAQUE:
@@ -428,6 +452,7 @@ std::unique_ptr<Literal> Literal::Transpose(
// The shape with affine layout resulting from that operation will be
// F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the
// most minor.
+ //
// Essentially, given MinMaj(Di) the position of the Di dimension within the
// minor to major vector, and given T(Di) the index that the original Di
// dimension has within the transposed array, a layout is affine if
@@ -536,6 +561,9 @@ string Literal::GetAsString(
}
case F16:
return tensorflow::strings::StrCat(Get<half>(multi_index));
+ case BF16:
+ return tensorflow::strings::StrCat(
+ static_cast<float>(Get<bfloat16>(multi_index)));
default:
return tensorflow::strings::StrCat(
"[", PrimitiveType_Name(shape().element_type()), "]");
@@ -569,9 +597,17 @@ int64 Literal::LinearIndex(
return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index);
}
-string Literal::ToString() const {
+string Literal::ToString(bool print_layout) const {
std::vector<string> pieces;
+ auto shape_to_string = [print_layout](const Shape& shape) {
+ if (print_layout) {
+ return ShapeUtil::HumanStringWithLayout(shape);
+ } else {
+ return ShapeUtil::HumanString(shape);
+ }
+ };
+
auto element_to_string =
[this](tensorflow::gtl::ArraySlice<int64> indices) -> string {
PrimitiveType element_type = shape().element_type();
@@ -585,7 +621,7 @@ string Literal::ToString() const {
// TODO(b/32894291): refactor this code to reduce code duplication.
if (ShapeUtil::IsTuple(shape())) {
- pieces.push_back(ShapeUtil::HumanString(shape()));
+ pieces.push_back(shape_to_string(shape()));
pieces.push_back(" (\n");
pieces.push_back(tensorflow::str_util::Join(
tuple_literals(), ",\n", [](string* out, const Literal& element) {
@@ -601,7 +637,7 @@ string Literal::ToString() const {
}
pieces.push_back("}");
} else if (ShapeUtil::Rank(shape()) == 2) {
- pieces.push_back(ShapeUtil::HumanString(shape()));
+ pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(" { ");
@@ -613,7 +649,7 @@ string Literal::ToString() const {
}
pieces.push_back("}");
} else if (ShapeUtil::Rank(shape()) == 3) {
- pieces.push_back(ShapeUtil::HumanString(shape()));
+ pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(i0 > 0 ? ",\n{" : "{");
@@ -628,7 +664,7 @@ string Literal::ToString() const {
}
pieces.push_back("\n}");
} else if (ShapeUtil::Rank(shape()) == 4) {
- pieces.push_back(ShapeUtil::HumanString(shape()));
+ pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0));
@@ -649,7 +685,7 @@ string Literal::ToString() const {
}
pieces.push_back("}");
} else if (ShapeUtil::Rank(shape()) == 5) {
- pieces.push_back(ShapeUtil::HumanString(shape()));
+ pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0));
@@ -676,7 +712,7 @@ string Literal::ToString() const {
}
pieces.push_back("}");
} else {
- pieces.push_back(ShapeUtil::HumanString(shape()));
+ pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {...}");
}
@@ -735,6 +771,8 @@ void* Literal::MutableInternalData() {
return reinterpret_cast<void*>(c64s_.data());
case F16:
return reinterpret_cast<void*>(f16s_.data());
+ case BF16:
+ return reinterpret_cast<void*>(bf16s_.data());
default:
LOG(FATAL) << "primitive type not supported in literals: "
<< PrimitiveType_Name(shape().element_type());
@@ -777,6 +815,9 @@ void Literal::Reserve(int64 num_elements) {
case F16:
Resize<half>(num_elements, static_cast<half>(0.0f));
break;
+ case BF16:
+ Resize<bfloat16>(num_elements, static_cast<bfloat16>(0.0f));
+ break;
default:
LOG(FATAL) << "primitive type not supported in literals: "
<< PrimitiveType_Name(shape().element_type());
@@ -816,6 +857,9 @@ tensorflow::Status Literal::ValidateLiteral() const {
case F16:
actual = f16s().size() / sizeof(half);
break;
+ case BF16:
+ actual = bf16s().size();
+ break;
default:
return tensorflow::errors::Unimplemented(
"unhandled element type for literal validation: " +
@@ -912,6 +956,7 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
CONVERT_IF_TYPES_MATCH(F16)
CONVERT_IF_TYPES_MATCH(F32)
CONVERT_IF_TYPES_MATCH(F64)
+ CONVERT_IF_TYPES_MATCH(BF16)
#undef CONVERT_IF_TYPES_MATCH
case C64:
return ConvertToC64<primitive_src_type>(src_literal);
@@ -941,8 +986,9 @@ StatusOr<std::unique_ptr<Literal>> Literal::Convert(
CONVERT_IF_DEST_TYPE_MATCHES(F16)
CONVERT_IF_DEST_TYPE_MATCHES(F32)
CONVERT_IF_DEST_TYPE_MATCHES(F64)
+ CONVERT_IF_DEST_TYPE_MATCHES(BF16)
#undef CONVERT_IF_DEST_TYPE_MATCHES
- // Other types are not yet supported.
+ // Other types are not yet supported.
default:
return InvalidArgument("Unimplemented: Convert from type %s to type %s",
PrimitiveType_Name(shape().element_type()).c_str(),
@@ -1011,6 +1057,8 @@ bool Literal::operator==(const Literal& other) const {
return EqualElements<double>(*this, other, 0, &multi_index);
case F16:
return EqualElements<half>(*this, other, 0, &multi_index);
+ case BF16:
+ return EqualElements<bfloat16>(*this, other, 0, &multi_index);
case C64:
return EqualElements<complex64>(*this, other, 0, &multi_index);
default:
@@ -1120,14 +1168,19 @@ tensorflow::gtl::MutableArraySlice<complex64> Literal::GetMutableArraySlice() {
template <>
tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice<half>() {
- // TODO - there is an endianess problem here. fix it, or wait for uint16
- // support in protobuf
auto values = mutable_f16s();
return tensorflow::gtl::MutableArraySlice<half>(values->data(),
values->size());
}
template <>
+tensorflow::gtl::MutableArraySlice<bfloat16>
+Literal::GetMutableArraySlice<bfloat16>() {
+ auto values = mutable_bf16s();
+ return {values->data(), values->size()};
+}
+
+template <>
tensorflow::gtl::ArraySlice<bool> Literal::GetArraySlice<bool>() const {
CHECK_EQ(shape().element_type(), PRED);
return tensorflow::gtl::ArraySlice<bool>(
@@ -1198,6 +1251,12 @@ tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const {
}
template <>
+tensorflow::gtl::ArraySlice<bfloat16> Literal::GetArraySlice<bfloat16>() const {
+ CHECK_EQ(shape().element_type(), BF16);
+ return {bf16s().data(), bf16s().size()};
+}
+
+template <>
tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>()
const {
CHECK_EQ(shape().element_type(), C64);
@@ -1245,6 +1304,9 @@ bool Literal::IsAll(int8 value) const {
return AllElementsEqualValue<double>(*this, value);
case F16:
return AllElementsEqualValue<half>(*this, static_cast<half>(value));
+ case BF16:
+ return AllElementsEqualValue<bfloat16>(*this,
+ static_cast<bfloat16>(value));
case PRED:
if (value == 0) {
return AllElementsEqualValue<bool>(*this, false);
@@ -1266,6 +1328,9 @@ bool Literal::IsAllFloat(float value) const {
return AllElementsEqualValue<double>(*this, value);
case F16:
return AllElementsEqualValue<half>(*this, static_cast<half>(value));
+ case BF16:
+ return AllElementsEqualValue<bfloat16>(*this,
+ static_cast<bfloat16>(value));
default:
return false;
}
@@ -1302,6 +1367,8 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
return Get<complex64>(indices) == complex64(0.0f, 0.0f);
case F16:
return Get<half>(indices) == static_cast<half>(0.0f);
+ case BF16:
+ return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
case PRED:
return Get<bool>(indices) == false;
default:
@@ -1370,6 +1437,12 @@ void Literal::Resize<half>(int64 num_elements, half value) {
}
template <>
+void Literal::Resize<bfloat16>(int64 num_elements, bfloat16 value) {
+ CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
+ mutable_bf16s()->resize(num_elements, value);
+}
+
+template <>
void Literal::Resize<complex64>(int64 num_elements, complex64 value) {
CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
mutable_c64s()->resize(num_elements, value);
@@ -1417,6 +1490,19 @@ LiteralProto Literal::ToProto() const {
*proto.mutable_f16s() =
string(reinterpret_cast<const char*>(f16s_.data()),
f16s_.size() * sizeof(half));
+ if (!kLittleEndian) {
+ ConvertEndianShort(const_cast<char*>(proto.mutable_f16s()->data()),
+ proto.f16s().size());
+ }
+ break;
+ case BF16:
+ *proto.mutable_bf16s() =
+ string(reinterpret_cast<const char*>(bf16s_.data()),
+ bf16s_.size() * sizeof(bfloat16));
+ if (!kLittleEndian) {
+ ConvertEndianShort(const_cast<char*>(proto.mutable_bf16s()->data()),
+ proto.bf16s().size());
+ }
break;
case F32:
CopyToRepeatedField(proto.mutable_f32s(), f32s());
@@ -1485,6 +1571,21 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) {
CHECK_EQ(0, s.size() % sizeof(half));
f16s_ = std::vector<half>(s.size() / sizeof(half));
memcpy(f16s_.data(), s.data(), s.size());
+
+ if (!kLittleEndian) {
+ ConvertEndianShort(reinterpret_cast<char*>(f16s_.data()), s.size());
+ }
+ break;
+ }
+ case BF16: {
+ const string& s(literal_proto.bf16s());
+ CHECK_EQ(0, s.size() % sizeof(bfloat16));
+ bf16s_ = std::vector<bfloat16>(s.size() / sizeof(bfloat16));
+ memcpy(bf16s_.data(), s.data(), s.size());
+
+ if (!kLittleEndian) {
+ ConvertEndianShort(reinterpret_cast<char*>(bf16s_.data()), s.size());
+ }
break;
}
case F32:
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index a1e288829f..f37e529caf 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -163,6 +163,11 @@ class Literal {
const std::vector<complex64>& c64s() const { return c64s_; }
std::vector<complex64>* mutable_c64s() { return &c64s_; }
+ int bf16s_size() const { return bf16s().size(); }
+ bfloat16 bf16s(int i) const { return bf16s_[i]; }
+ const std::vector<bfloat16>& bf16s() const { return bf16s_; }
+ std::vector<bfloat16>* mutable_bf16s() { return &bf16s_; }
+
int tuple_literals_size() const { return tuple_literals().size(); }
const Literal& tuple_literals(int i) const { return tuple_literals_[i]; }
Literal* add_tuple_literals() {
@@ -450,7 +455,7 @@ class Literal {
tensorflow::Status ValidateLiteral() const;
// Returns a string representation of the literal value.
- string ToString() const;
+ string ToString(bool print_layout = false) const;
// Invokes the "per cell" callback for each element in the provided
// literal with the element's indices and a string representation of
@@ -622,6 +627,7 @@ class Literal {
std::vector<uint16> u16s_;
std::vector<uint32> u32s_;
std::vector<uint64> u64s_;
+ std::vector<bfloat16> bf16s_;
std::vector<half> f16s_;
std::vector<float> f32s_;
std::vector<double> f64s_;
@@ -675,6 +681,9 @@ template <>
tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const;
template <>
+tensorflow::gtl::ArraySlice<bfloat16> Literal::GetArraySlice<bfloat16>() const;
+
+template <>
tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>()
const;
@@ -715,6 +724,9 @@ template <>
tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice();
template <>
+tensorflow::gtl::MutableArraySlice<bfloat16> Literal::GetMutableArraySlice();
+
+template <>
tensorflow::gtl::MutableArraySlice<complex64> Literal::GetMutableArraySlice();
template <>
@@ -748,6 +760,9 @@ template <>
void Literal::Resize<half>(int64 num_elements, half value);
template <>
+void Literal::Resize<bfloat16>(int64 num_elements, bfloat16 value);
+
+template <>
void Literal::Resize<complex64>(int64 num_elements, complex64 value);
template <typename NativeT>
@@ -990,6 +1005,14 @@ inline half Literal::Get<half>(
return GetArraySlice<half>()[linear_index];
}
+template <>
+inline bfloat16 Literal::Get<bfloat16>(
+ tensorflow::gtl::ArraySlice<int64> multi_index) const {
+ CHECK(shape().element_type() == BF16);
+ int64 linear_index = LinearIndex(multi_index);
+ return GetArraySlice<bfloat16>()[linear_index];
+}
+
template <typename NativeT>
void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
NativeT value) {
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc
index 6d596da4ad..1e08101759 100644
--- a/tensorflow/compiler/xla/literal_util_test.cc
+++ b/tensorflow/compiler/xla/literal_util_test.cc
@@ -110,6 +110,18 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) {
auto c64_lit = Literal::CreateR0<complex64>({3.14f, 2.78f});
ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString());
+
+ auto bf16_lit = Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
+ ASSERT_EQ("0.5", bf16_lit->ToString());
+
+ // 3.14 will be rounded to 3.125 in bfloat16 format (Round to nearest even).
+ auto bf16_lit_truncated =
+ Literal::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
+ ASSERT_EQ("3.140625", bf16_lit_truncated->ToString());
+
+ auto bf16_lit_truncated2 =
+ Literal::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
+ ASSERT_EQ("9", bf16_lit_truncated2->ToString());
}
TEST_F(LiteralUtilTest, LiteralVectorToString) {
@@ -397,6 +409,18 @@ TEST_F(LiteralUtilTest, IsAll) {
EXPECT_FALSE(Literal::CreateR2<half>({{h8}, {h9}})->IsAll(8));
EXPECT_FALSE(Literal::CreateR2<half>({{h9}, {h8}})->IsAll(8));
+ bfloat16 b8(8.0f);
+ bfloat16 b9(9.0f);
+
+ EXPECT_TRUE(Literal::CreateR2<bfloat16>({{b8}, {b8}})->IsAll(8));
+ EXPECT_FALSE(Literal::CreateR2<bfloat16>({{b8}, {b9}})->IsAll(8));
+ EXPECT_FALSE(Literal::CreateR2<bfloat16>({{b9}, {b8}})->IsAll(8));
+
+ // 9.001 will be truncated to 9.0
+ bfloat16 b91(9.001f);
+ bfloat16 b90(9.00f);
+ EXPECT_TRUE(Literal::CreateR2<bfloat16>({{b91}, {b90}})->IsAll(9.0));
+
complex64 c8_9 = {8, 9};
EXPECT_FALSE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAll(8));
@@ -691,6 +715,30 @@ TEST_F(LiteralUtilTest, PopulateR2C64) {
EXPECT_EQ(output, *expected);
}
+TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
+ Literal output;
+ bfloat16 h(0.25f);
+ output.PopulateWithValue<bfloat16>(h, {});
+ auto expected = Literal::CreateR0<bfloat16>(h);
+ EXPECT_EQ(output, *expected);
+}
+
+TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
+ Literal output;
+ bfloat16 h(0.5f);
+ output.PopulateWithValue<bfloat16>(h, {3});
+ auto expected = Literal::CreateR1<bfloat16>({h, h, h});
+ EXPECT_EQ(output, *expected);
+}
+
+TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
+ Literal output;
+ bfloat16 h(2.0f);
+ output.PopulateWithValue<bfloat16>(h, {2, 2});
+ auto expected = Literal::CreateR2<bfloat16>({{h, h}, {h, h}});
+ EXPECT_EQ(output, *expected);
+}
+
TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
Literal output;
output.PopulateWithValue<float>(2.5f, {});
@@ -975,6 +1023,14 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
{{half(26.0), half(0.0), half(28.0), half(0.0)},
{half(0.0), half(31.0), half(0.0), half(33.0)}},
}}, layout_r4_dim0major_);
+ auto bf16 = Literal::CreateR4WithLayout<bfloat16>({{
+ {{bfloat16(10.0), bfloat16(0.0), bfloat16(12.0), bfloat16(0.0)},
+ {bfloat16(0.0), bfloat16(15.0), bfloat16(0.0), bfloat16(17.0)}},
+ {{bfloat16(0.0), bfloat16(19.0), bfloat16(0.0), bfloat16(21.0)},
+ {bfloat16(22.0), bfloat16(0.0), bfloat16(24.0), bfloat16(0.0)}},
+ {{bfloat16(26.0), bfloat16(0.0), bfloat16(28.0), bfloat16(0.0)},
+ {bfloat16(0.0), bfloat16(31.0), bfloat16(0.0), bfloat16(33.0)}},
+ }}, layout_r4_dim0major_);
auto f32 = Literal::CreateR4WithLayout<float>({{
{{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
{{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
@@ -1008,6 +1064,12 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
conv = s8->Convert(PRED).ConsumeValueOrDie();
EXPECT_EQ(*conv, *pred);
+ conv = bf16->Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(*conv, *s32);
+
+ conv = bf16->Convert(F32).ConsumeValueOrDie();
+ EXPECT_EQ(*conv, *f32);
+
conv = pred->Convert(S32).ConsumeValueOrDie();
EXPECT_EQ(*conv, *int32_pred);
diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc
index 2113b5e06f..2bce56b7bd 100644
--- a/tensorflow/compiler/xla/primitive_util.cc
+++ b/tensorflow/compiler/xla/primitive_util.cc
@@ -79,6 +79,11 @@ PrimitiveType NativeToPrimitiveType<double>() {
}
template <>
+PrimitiveType NativeToPrimitiveType<bfloat16>() {
+ return BF16;
+}
+
+template <>
PrimitiveType NativeToPrimitiveType<half>() {
return F16;
}
@@ -89,7 +94,7 @@ PrimitiveType NativeToPrimitiveType<complex64>() {
}
bool IsFloatingPointType(PrimitiveType type) {
- return type == F16 || type == F32 || type == F64;
+ return type == F16 || type == F32 || type == F64 || type == BF16;
}
bool IsComplexType(PrimitiveType type) { return type == C64; }
@@ -118,6 +123,7 @@ int BitWidth(PrimitiveType type) {
case S16:
case U16:
case F16:
+ case BF16:
return 16;
case U32:
diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h
index a49c8b86fc..19c6a13888 100644
--- a/tensorflow/compiler/xla/primitive_util.h
+++ b/tensorflow/compiler/xla/primitive_util.h
@@ -77,6 +77,8 @@ template <>
PrimitiveType NativeToPrimitiveType<double>();
template <>
PrimitiveType NativeToPrimitiveType<half>();
+template <>
+PrimitiveType NativeToPrimitiveType<bfloat16>();
// Complex
template <>
@@ -167,6 +169,11 @@ struct PrimitiveTypeToNative<F16> {
using type = half;
};
+template <>
+struct PrimitiveTypeToNative<BF16> {
+ using type = bfloat16;
+};
+
// Complex
template <>
struct PrimitiveTypeToNative<C64> {
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 521fe411a4..cd0a316f70 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1778,7 +1778,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
- "//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
],
)
@@ -1849,7 +1848,6 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
- "//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc
index 9abe30e3f3..05f2d06278 100644
--- a/tensorflow/compiler/xla/service/backend.cc
+++ b/tensorflow/compiler/xla/service/backend.cc
@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#define EIGEN_USE_THREADS
+
#include "tensorflow/compiler/xla/service/backend.h"
#include <algorithm>
#include <string>
#include <utility>
-#define EIGEN_USE_THREADS
-
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index b422b22df9..3c5b360c8e 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -497,19 +497,19 @@ Status GatherComputationsByAllocationType(
std::vector<const HloComputation*>* global_computations) {
// Create a worklist of computations paired with whether the allocation must
// be thread-local.
- std::deque<std::pair<HloComputation*, bool>> worklist;
+ std::deque<std::pair<const HloComputation*, bool>> worklist;
worklist.push_back(std::make_pair(module->entry_computation(),
/*is_thread_local*/ false));
// Sets for quickly checking membership. Computations are returned in vectors
// for stable iteration.
- FlatSet<HloComputation*> thread_local_set;
- FlatSet<HloComputation*> global_set;
+ FlatSet<const HloComputation*> thread_local_set;
+ FlatSet<const HloComputation*> global_set;
while (!worklist.empty()) {
auto worklist_front = worklist.front();
worklist.pop_front();
- HloComputation* computation = worklist_front.first;
+ const HloComputation* computation = worklist_front.first;
bool is_thread_local = worklist_front.second;
bool in_thread_local_set = thread_local_set.count(computation) > 0;
bool in_global_set = global_set.count(computation) > 0;
@@ -653,7 +653,7 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
}
if (allow_input_output_aliasing_ && allocation->maybe_live_out()) {
- HloComputation* entry_computation =
+ const HloComputation* entry_computation =
assignment->module_->entry_computation();
for (auto param : entry_computation->parameter_instructions()) {
for (auto& param_buffer :
@@ -819,17 +819,6 @@ Status BufferAssigner::AssignBuffersForComputation(
continue;
}
- if (instruction->opcode() == HloOpcode::kRecv) {
- // Make sure that recv operations get a new unique allocation so that
- // don't share their buffer with any other operations.
- BufferAllocation* allocation = assignment->NewAllocation(
- *buffer, buffer_size, is_thread_local, /*is_reusable=*/false);
- allocation_indices.push_back(allocation->index());
- VLOG(3) << "New allocation #" << allocation->index()
- << " for recv: " << *buffer;
- continue;
- }
-
if (ShapeUtil::IsTuple(buffer->shape())) {
// TODO(b/34669761): Don't reuse tuple buffers because the GPU backend
// assumes longer buffer liveness than indicated by the analysis.
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 6213baee2f..4f6e69ebd4 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -280,6 +280,7 @@ cc_library(
srcs = ["dot_op_emitter.cc"],
hdrs = ["dot_op_emitter.h"],
deps = [
+ ":cpu_options",
":cpu_runtime",
":ir_emission_utils",
"//tensorflow/compiler/xla:shape_util",
@@ -290,8 +291,10 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+ "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/compiler/xla/service/llvm_ir:vector_support_library",
"//tensorflow/core:lib",
"@llvm//:core",
],
@@ -717,6 +720,7 @@ cc_library(
hdrs = ["cpu_options.h"],
deps = [
"//tensorflow/compiler/xla/service:hlo_module_config",
+ "//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc
index dba140d112..09f028463a 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc
@@ -15,11 +15,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+
namespace {
const char* const kXlaParallelCpuOption = "xla_cpu_parallel";
const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size";
const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce";
+const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor";
} // namespace
@@ -45,6 +48,19 @@ bool VectorizedReduceDisabled(const HloModuleConfig& config) {
return extra_options_map.count(kXlaOptimizeForSizeCpuOption) > 0;
}
+tensorflow::gtl::optional<int64> LlvmIrGemvTilingFactor(
+ const HloModuleConfig& config) {
+ const auto& extra_options_map =
+ config.debug_options().xla_backend_extra_options();
+ auto it = extra_options_map.find(kLlvmIrDotTilingFactor);
+ int64 tiling_factor;
+ if (it != extra_options_map.end() &&
+ tensorflow::strings::safe_strto64(it->second, &tiling_factor)) {
+ return tiling_factor;
+ }
+ return tensorflow::gtl::nullopt;
+}
+
} // namespace options
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h
index 5dc24ebc7b..6ba0fd2453 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_options.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h
@@ -27,6 +27,8 @@ namespace options {
bool CpuParallelBackendRequested(const HloModuleConfig& config);
bool OptimizeForSizeRequested(const HloModuleConfig& config);
bool VectorizedReduceDisabled(const HloModuleConfig& config);
+tensorflow::gtl::optional<int64> LlvmIrGemvTilingFactor(
+ const HloModuleConfig& config);
} // namespace options
} // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
index f8e260dd90..f385829cdf 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
@@ -12,15 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
+#define EIGEN_USE_THREADS
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include <memory>
#include <string>
#include <tuple>
-#define EIGEN_USE_THREADS
-
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index e57d49172b..2a447a54b0 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -25,7 +25,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
@@ -38,6 +40,450 @@ using llvm_ir::SetToFirstInsertPoint;
namespace cpu {
+namespace {
+// Loads a tile of values from a 2D tensor.
+class TileLoader {
+ public:
+ // Constructs a TileLoader that will load a tile consisting of
+ // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at
+ // `major_dim_offset` in the major dimension. The tile size along the minor
+ // dimension is the vector size, and that is implicitly determined by `vsl`.
+ TileLoader(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder,
+ llvm::Value* matrix, int64 matrix_size_along_minor_dim,
+ llvm::Value* major_dim_offset, int64 tile_size_along_major_dim)
+ : vsl_(vsl) {
+ pointers_.reserve(tile_size_along_major_dim);
+ for (int64 i = 0; i < tile_size_along_major_dim; i++) {
+ llvm::Value* total_offset = ir_builder->CreateMul(
+ ir_builder->getInt64(matrix_size_along_minor_dim),
+ ir_builder->CreateAdd(ir_builder->getInt64(i), major_dim_offset));
+ pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset));
+ }
+ }
+
+ // Load a tile consisting of `tile_size_along_major_dim_` vectors starting at
+ // `major_dim_offset_` in the major dimension and `minor_dim_offset` in the
+ // minor dimension.
+ std::vector<llvm::Value*> LoadTile(llvm::Value* minor_dim_offset) const {
+ std::vector<llvm::Value*> result;
+ result.reserve(pointers_.size());
+ for (const auto& pointer : pointers_) {
+ result.push_back(vsl_->LoadVector(pointer, minor_dim_offset));
+ }
+ return result;
+ }
+
+ private:
+ VectorSupportLibrary* vsl_;
+ std::vector<llvm::Value*> pointers_;
+};
+
+// Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the
+// layout of the vector does not matter). This implementation uses a tiling
+// scheme to improve performance.
+//
+// We logically separate the LHS matrix into four segments:
+//
+// +----------------------+---+
+// | | |
+// | | |
+// | A | B |
+// | | |
+// | | |
+// | | |
+// +----------------------+---+
+// | C | D |
+// +----------------------+---+
+//
+// where A is the largest submatrix of the LHS that can be evenly dividied into
+// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have:
+//
+// +---+---+---+---+ +--+--+--+--+
+// |M00|M10|M20|M30| |V0|V1|V2|V3|
+// +---+---+---+---+ +--+--+--+--+
+// |M01|M11|M21|M31| and |V0|V1|V2|V3|
+// +---+---+---+---+ +--+--+--+--+
+// |M02|M12|M22|M32| |V0|V1|V2|V3|
+// +---+---+---+---+ +--+--+--+--+
+// |M03|M13|M23|M33| |V0|V1|V2|V3|
+// +---+---+---+---+ +--+--+--+--+
+//
+// (Legend: rows are horizontal and columns are vertical; and each column is one
+// llvm::Value of a vector type)
+//
+// where:
+//
+// a. The left tile is from the column major left matrix.
+// b. The right tile is an elementwise broadcast of a [V0, V1, V2, V3]
+// vector loaded from the RHS vector.
+//
+// As we iterate through the column dimension, we compute the change to the
+// result vector by an elementwise multiplication between the two tiles above
+// followed by a reduction along the major dimension:
+//
+// +-----------------------------------+
+// | M00*V0 + M10*V1 + M20*V2 + M30*V3 |
+// +-----------------------------------+
+// | M01*V0 + M11*V1 + M21*V2 + M31*V3 |
+// Result[R:R+4] += +-----------------------------------+
+// | M02*V0 + M12*V1 + M22*V2 + M32*V3 |
+// +-----------------------------------+
+// | M03*V0 + M13*V1 + M23*V2 + M33*V3 |
+// +-----------------------------------+
+//
+// Where R is the starting row for the tile.
+//
+// We have an inner epilogue loop to deal with the "C" submatrix and an outer
+// epilogue loop to deal with the B,D submarix.
+//
+// TODO(sanjoy): We should investigate if using gather loads and scatter stores
+// can be used here have the same inner loop for both column-major and row-major
+// matrix-vector products.
+class ColumnMajorMatrixVectorProductEmitter {
+ public:
+ ColumnMajorMatrixVectorProductEmitter(PrimitiveType scalar_type,
+ int64 tile_rows, int64 tile_cols,
+ int64 m, int64 k, llvm::Value* lhs,
+ llvm::Value* rhs, llvm::Value* result,
+ llvm::IRBuilder<>* ir_builder)
+ : scalar_type_(scalar_type),
+ tile_rows_(tile_rows),
+ tile_cols_(tile_cols),
+ m_(m),
+ k_(k),
+ lhs_(lhs),
+ rhs_(rhs),
+ result_(result),
+ ir_builder_(ir_builder),
+ ksl_(ir_builder_),
+ vsl_(scalar_type_, /*vector_size=*/tile_rows_, ir_builder_, "") {
+ CHECK(tile_rows_ > 0 && IsPowerOfTwo(static_cast<uint64>(tile_rows_)));
+ }
+
+ void Emit();
+
+ private:
+ void EmitOuterLoopBody(llvm::Value* column, int64 column_count,
+ bool is_first_column);
+
+ TileLoader GetLhsTileLoader(llvm::Value* column_start, int64 column_count) {
+ return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_,
+ /*matrix_size_along_minor_dim=*/m_,
+ /*major_dim_offset=*/column_start,
+ /*tile_size_along_major_dim=*/column_count);
+ }
+
+ // Load a tile of values from the RHS. For the RHS a "tile" is a contiguous
+ // sequnce of `count` values, each one broadcasted to the vector width.
+ std::vector<llvm::Value*> LoadRhsTile(llvm::Value* offset, int64 count) {
+ llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset);
+ std::vector<llvm::Value*> result;
+ result.reserve(count);
+ for (int64 i = 0; i < count; i++) {
+ result.push_back(vsl_.LoadBroadcast(base_pointer, i));
+ }
+ return result;
+ }
+
+ void EmitInnerLoopTiled(TileLoader* lhs_tile_loader,
+ const std::vector<llvm::Value*>& rhs_tile,
+ int64 columns, bool is_first_column);
+
+ void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns,
+ bool is_first_tiled_column);
+
+ PrimitiveType scalar_type_;
+ int64 tile_rows_;
+ int64 tile_cols_;
+ int64 m_;
+ int64 k_;
+ llvm::Value* lhs_;
+ llvm::Value* rhs_;
+ llvm::Value* result_;
+ llvm::IRBuilder<>* ir_builder_;
+ KernelSupportLibrary ksl_;
+ VectorSupportLibrary vsl_;
+};
+
+void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody(
+ llvm::Value* column, int64 column_count, bool is_first_column) {
+ TileLoader lhs_tile_loader = GetLhsTileLoader(/*column_start=*/column,
+ /*column_count=*/column_count);
+
+ std::vector<llvm::Value*> rhs_tile =
+ LoadRhsTile(column, /*count=*/column_count);
+ EmitInnerLoopTiled(&lhs_tile_loader, rhs_tile,
+ /*columns=*/column_count, is_first_column);
+ EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column);
+}
+
+void ColumnMajorMatrixVectorProductEmitter::Emit() {
+ // See the comment on the class declaration for the algorithm used here.
+ int64 column_remainder = k_ % tile_cols_;
+ int64 column_limit = k_ - column_remainder;
+
+ ksl_.For("dot.outer.tiled",
+ /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols_,
+ [&](llvm::Value* column, bool is_first_column) {
+ EmitOuterLoopBody(column, tile_cols_, is_first_column);
+ });
+
+ if (column_remainder != 0) {
+ EmitOuterLoopBody(ir_builder_->getInt64(column_limit), column_remainder,
+ column_limit == 0);
+ }
+}
+
+void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
+ TileLoader* lhs_tile_loader, const std::vector<llvm::Value*>& rhs_tile,
+ int64 columns, bool is_first_column) {
+ int64 row_limit = m_ - (m_ % tile_rows_);
+
+ ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit,
+ /*step=*/tile_rows_, [&](llvm::Value* row) {
+ std::vector<llvm::Value*> lhs_tile =
+ lhs_tile_loader->LoadTile(/*minor_dim_offset=*/row);
+ llvm::Value* accumulator = is_first_column
+ ? vsl_.GetZeroVector()
+ : vsl_.LoadVector(result_, row);
+ for (int i = 0; i < columns; i++) {
+ accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator);
+ }
+ vsl_.StoreVector(accumulator, result_, row);
+ });
+}
+
+void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
+ llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) {
+ int64 row_start = m_ - (m_ % tile_rows_);
+ if (row_start == m_) {
+ return;
+ }
+
+ llvm::Value* columns_llvm = ir_builder_->getInt64(columns);
+
+ // for (col = current_tile_col; col < (columns + current_tile_col); col++)
+ // for (row = row_start, row < m_; row++) {
+ // result[row] += lhs[row, col] * rhs[col]
+ // // Also take into account that if col is 0 then result[row] is not
+ // // initialized.
+ // }
+
+ ksl_.For(
+ "dot.inner.epilg.outer", /*start=*/current_tile_col,
+ /*end=*/ir_builder_->CreateAdd(columns_llvm, current_tile_col),
+ /*step=*/1, /*peel_first_iteration=*/false,
+ [&](llvm::Value* col, llvm::Value* is_first_scalar_col) {
+ llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col);
+ llvm::Value* total_offset =
+ ir_builder_->CreateMul(col, ir_builder_->getInt64(m_));
+ llvm::Value* lhs_base_pointer =
+ vsl_.ComputeOffsetPointer(lhs_, total_offset);
+ ksl_.For(
+ "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m_,
+ /*step=*/1, [&](llvm::Value* scalar_row) {
+ llvm::Value* product = vsl_.Mul(
+ vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element);
+ llvm::Value* setting_result_first_time = ir_builder_->CreateAnd(
+ is_first_scalar_col,
+ ir_builder_->getInt1(is_first_tiled_column));
+ ksl_.If(
+ setting_result_first_time,
+ [&]() { vsl_.StoreScalar(product, result_, scalar_row); },
+ [&]() {
+ vsl_.StoreScalar(
+ vsl_.Add(vsl_.LoadScalar(result_, scalar_row), product),
+ result_, scalar_row);
+ });
+ });
+ });
+}
+
+// Computes a dot product between "[M,K]{1,0} lhs" with a [K,1] vector (the
+// layout of the vector does not matter). This implementation uses a tiling
+// scheme to improve performance.
+//
+// We logically separate the LHS matrix into four segments:
+//
+// +----------------------+---+
+// | | |
+// | | |
+// | A | B |
+// | | |
+// | | |
+// | | |
+// +----------------------+---+
+// | C | D |
+// +----------------------+---+
+//
+// where A is the largest submatrix of the LHS that can be evenly dividied into
+// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have:
+//
+// +---+---+---+---+
+// |M00|M10|M20|M30|
+// +---+---+---+---+ +--+--+--+--+
+// |M01|M11|M21|M31| and |V0|V1|V2|V3|
+// +---+---+---+---+ +--+--+--+--+
+// |M02|M12|M22|M32|
+// +---+---+---+---+
+// |M03|M13|M23|M33|
+// +---+---+---+---+
+//
+// (Legend: rows are horizontal and columns are vertical; and each row is one
+// llvm::Value of a vector type)
+//
+// where:
+//
+// a. The left tile is loaded from the row major left matrix.
+// b. The right vector is loaded from the RHS vector.
+//
+// We keep 4 vector accumulators accumulating the following four vector
+// expressions as we iterate over the row dimension:
+//
+// +------+------+------+------+
+// |M0I*V0|M1I*V1|M2I*V2|M3I*V3| for I in [0,4)
+// +------+------+------+------+
+//
+// In the end we do a horizontal reduction over these 4 vector accumulators to
+// get 4 values in the result vector.
+//
+// We have an inner epilogue loop to deal with the "B" sub-matrix and an outer
+// epilogue loop to deal with the C,D submatrix.
+class RowMajorMatrixVectorProductEmitter {
+ public:
+ RowMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, int64 tile_rows,
+ int64 tile_cols, int64 m, int64 k,
+ llvm::Value* lhs, llvm::Value* rhs,
+ llvm::Value* result,
+ llvm::IRBuilder<>* ir_builder)
+ : scalar_type_(scalar_type),
+ tile_rows_(tile_rows),
+ tile_cols_(tile_cols),
+ m_(m),
+ k_(k),
+ lhs_(lhs),
+ rhs_(rhs),
+ result_(result),
+ ir_builder_(ir_builder),
+ ksl_(ir_builder_),
+ vsl_(scalar_type_, /*vector_size=*/tile_cols_, ir_builder_, "") {
+ CHECK(tile_cols_ > 0 && IsPowerOfTwo(static_cast<uint64>(tile_cols_)));
+ }
+
+ void Emit();
+
+ private:
+ TileLoader GetLhsTileLoader(llvm::Value* row_start, int64 row_count) {
+ return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_,
+ /*matrix_size_along_minor_dim=*/k_,
+ /*major_dim_offset=*/row_start,
+ /*tile_size_along_major_dim=*/row_count);
+ }
+
+ void EmitOuterLoopBody(llvm::Value* row, int64 row_count);
+
+ void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, int64 rows,
+ std::vector<VectorVariable>* vector_accumulators);
+
+ void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows,
+ std::vector<ScalarVariable>* scalar_accumulators);
+
+ PrimitiveType scalar_type_;
+ int64 tile_rows_;
+ int64 tile_cols_;
+ int64 m_;
+ int64 k_;
+ llvm::Value* lhs_;
+ llvm::Value* rhs_;
+ llvm::Value* result_;
+ llvm::IRBuilder<>* ir_builder_;
+ KernelSupportLibrary ksl_;
+ VectorSupportLibrary vsl_;
+};
+
+void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row,
+ int64 row_count) {
+ TileLoader lhs_tile_loader = GetLhsTileLoader(/*row_start=*/row,
+ /*row_count=*/row_count);
+ std::vector<VectorVariable> vector_accumulators;
+ std::vector<ScalarVariable> scalar_accumulators;
+ for (int i = 0; i < row_count; i++) {
+ vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector());
+ scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar());
+ }
+ EmitInnerLoopTiled(&lhs_tile_loader, /*rows=*/row_count,
+ &vector_accumulators);
+ EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count,
+ &scalar_accumulators);
+
+ for (int i = 0; i < row_count; i++) {
+ llvm::Value* result_value =
+ vsl_.Add(vsl_.AddReduce(vector_accumulators[i].Get()),
+ scalar_accumulators[i].Get());
+ llvm::Value* offset = ir_builder_->CreateAdd(ir_builder_->getInt64(i), row);
+ vsl_.StoreScalar(result_value, result_, offset);
+ }
+}
+
+void RowMajorMatrixVectorProductEmitter::Emit() {
+ // See the comment on the class declaration for the algorithm used here.
+ int64 row_remainder = m_ % tile_rows_;
+ int64 row_limit = m_ - row_remainder;
+
+ ksl_.For("dot.outer.tiled",
+ /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows_,
+ [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows_); });
+
+ if (row_remainder != 0) {
+ EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder);
+ }
+}
+
+void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
+ TileLoader* lhs_tile_loader, int64 rows,
+ std::vector<VectorVariable>* vector_accumulators) {
+ int64 column_limit = k_ - (k_ % tile_cols_);
+
+ ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit,
+ /*step=*/tile_cols_, [&](llvm::Value* col) {
+ std::vector<llvm::Value*> lhs_tile =
+ lhs_tile_loader->LoadTile(/*minor_dim_offset=*/col);
+ llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col);
+ for (int i = 0; i < rows; i++) {
+ llvm::Value* old_sum = (*vector_accumulators)[i].Get();
+ (*vector_accumulators)[i].Set(
+ vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i])));
+ }
+ });
+}
+
+void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
+ llvm::Value* current_tile_row, int64 rows,
+ std::vector<ScalarVariable>* scalar_accumulators) {
+ int64 column_start = k_ - (k_ % tile_cols_);
+ if (column_start == k_) {
+ return;
+ }
+
+ for (int r = 0; r < rows; r++) {
+ llvm::Value* total_offset = ir_builder_->CreateMul(
+ ir_builder_->CreateAdd(ir_builder_->getInt64(r), current_tile_row),
+ ir_builder_->getInt64(k_));
+ llvm::Value* lhs_base_pointer =
+ vsl_.ComputeOffsetPointer(lhs_, total_offset);
+ ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k_,
+ /*step=*/1, [&](llvm::Value* scalar_col) {
+ llvm::Value* product =
+ vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col),
+ vsl_.LoadScalar(rhs_, scalar_col));
+ llvm::Value* old_value = (*scalar_accumulators)[r].Get();
+ (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product));
+ });
+ }
+}
+
+} // namespace
+
DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs,
bool transpose_rhs,
const llvm_ir::IrArray& target_array,
@@ -72,6 +518,93 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs,
bool DotOpEmitter::ShapesAreLegalForRuntimeDot() const { return true; }
+bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
+ if (dot_.shape().dimensions_size() != 2 ||
+ ProfitableToImplementDotInUntiledLlvmIr(dot_) ==
+ DotInLlvmIrProfitable::kYes) {
+ return false;
+ }
+
+ if (!primitive_util::IsFloatingPointType(dot_.shape().element_type()) &&
+ !primitive_util::IsIntegralType(dot_.shape().element_type())) {
+ return false;
+ }
+
+ MatMultDims mat_mult_dims = GetMatMultDims();
+ bool is_column_major_matrix_vector = false;
+ bool is_row_major_matrix_vector = false;
+
+ int64 m, k;
+ bool swap_operands;
+
+ if (mat_mult_dims.m == 1) {
+ bool rhs_effectively_row_major =
+ transpose_rhs_ ^ !mat_mult_dims.rhs_column_major;
+ if (rhs_effectively_row_major) {
+ k = mat_mult_dims.k;
+ m = mat_mult_dims.n;
+ is_column_major_matrix_vector = true;
+ swap_operands = true;
+ } else {
+ k = mat_mult_dims.k;
+ m = mat_mult_dims.n;
+ is_row_major_matrix_vector = true;
+ swap_operands = true;
+ }
+ }
+
+ if (mat_mult_dims.n == 1) {
+ bool lhs_effectively_column_major =
+ transpose_lhs_ ^ mat_mult_dims.lhs_column_major;
+ if (lhs_effectively_column_major) {
+ m = mat_mult_dims.m;
+ k = mat_mult_dims.k;
+ is_column_major_matrix_vector = true;
+ swap_operands = false;
+ } else {
+ m = mat_mult_dims.m;
+ k = mat_mult_dims.k;
+ is_row_major_matrix_vector = true;
+ swap_operands = false;
+ }
+ }
+
+ if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) {
+ return false;
+ }
+
+ int64 tiling_factor = GetGemvTilingFactor();
+ CHECK_GT(tiling_factor, 0);
+
+ if (is_column_major_matrix_vector) {
+ VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m
+ << " and k = " << k;
+ ColumnMajorMatrixVectorProductEmitter emitter(
+ dot_.shape().element_type(), /*tile_rows=*/8,
+ /*tile_cols=*/tiling_factor, m, k,
+ swap_operands ? rhs_array_.GetBasePointer()
+ : lhs_array_.GetBasePointer(),
+ swap_operands ? lhs_array_.GetBasePointer()
+ : rhs_array_.GetBasePointer(),
+ target_array_.GetBasePointer(), ir_builder_);
+ emitter.Emit();
+ } else {
+ VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
+ << " and k = " << k;
+ RowMajorMatrixVectorProductEmitter emitter(
+ dot_.shape().element_type(), /*tile_rows=*/tiling_factor,
+ /*tile_cols=*/8, m, k,
+ swap_operands ? rhs_array_.GetBasePointer()
+ : lhs_array_.GetBasePointer(),
+ swap_operands ? lhs_array_.GetBasePointer()
+ : rhs_array_.GetBasePointer(),
+ target_array_.GetBasePointer(), ir_builder_);
+ emitter.Emit();
+ }
+
+ return true;
+}
+
tensorflow::Status DotOpEmitter::Emit() {
// The dot operation performs a sum of products over dimension 0 of the left
// hand side operand and dimension 1 of the right hand side operand.
@@ -105,6 +638,10 @@ tensorflow::Status DotOpEmitter::Emit() {
return EmitScalarDot();
}
+ if (EmitLlvmIrDotIfProfitable()) {
+ return Status::OK();
+ }
+
if (PotentiallyImplementedAsEigenDot(dot_)) {
return EmitCallToRuntime();
}
@@ -340,22 +877,17 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() {
//
// Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'.
- const Shape& lhs_shape = lhs_array_.GetShape();
- const Shape& rhs_shape = rhs_array_.GetShape();
+ MatMultDims mat_mult_dims = GetMatMultDims();
- CHECK(LayoutUtil::Equal(lhs_shape.layout(), rhs_shape.layout()));
+ CHECK_EQ(mat_mult_dims.lhs_column_major, mat_mult_dims.rhs_column_major);
- int64 m = lhs_shape.dimensions(transpose_lhs_ ? 1 : 0);
- int64 k = lhs_shape.dimensions(transpose_lhs_ ? 0 : 1);
- int64 n = rhs_shape.dimensions(transpose_rhs_ ? 0 : 1);
const llvm_ir::IrArray* lhs = &lhs_array_;
const llvm_ir::IrArray* rhs = &rhs_array_;
bool transpose_lhs = transpose_lhs_;
bool transpose_rhs = transpose_rhs_;
- bool is_column_major = lhs_shape.layout().minor_to_major(0) == 0;
- if (!is_column_major) {
- std::swap(m, n);
+ if (!mat_mult_dims.lhs_column_major) {
+ std::swap(mat_mult_dims.m, mat_mult_dims.n);
std::swap(lhs, rhs);
std::swap(transpose_lhs, transpose_rhs);
}
@@ -367,12 +899,27 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() {
float_ptr_type),
ir_builder_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type),
ir_builder_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type),
- ir_builder_->getInt64(m), ir_builder_->getInt64(n),
- ir_builder_->getInt64(k), ir_builder_->getInt32(transpose_lhs),
+ ir_builder_->getInt64(mat_mult_dims.m),
+ ir_builder_->getInt64(mat_mult_dims.n),
+ ir_builder_->getInt64(mat_mult_dims.k),
+ ir_builder_->getInt32(transpose_lhs),
ir_builder_->getInt32(transpose_rhs)});
return tensorflow::Status::OK();
}
+DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
+ CHECK_EQ(dot_.shape().dimensions_size(), 2);
+
+ const Shape& lhs_shape = lhs_array_.GetShape();
+ const Shape& rhs_shape = rhs_array_.GetShape();
+
+ return {lhs_shape.dimensions(transpose_lhs_ ? 1 : 0),
+ lhs_shape.dimensions(transpose_lhs_ ? 0 : 1),
+ rhs_shape.dimensions(transpose_rhs_ ? 0 : 1),
+ lhs_shape.layout().minor_to_major(0) == 0,
+ rhs_shape.layout().minor_to_major(0) == 0};
+}
+
llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest(
llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array,
int64 reduction_dimension, tensorflow::StringPiece name_suffix) {
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
index cfc1066045..470bf6ffb4 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_
#include "llvm/IR/IRBuilder.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
@@ -59,6 +60,10 @@ class DotOpEmitter {
// LHS and RHS) and store the results in the target.
tensorflow::Status EmitScalarDot();
+ // Emit an LLVM IR implementation of the dot operation if we can. Returns
+ // true if an LLVM IR implementation was emitted.
+ bool EmitLlvmIrDotIfProfitable();
+
// Emits a call to the CPU runtime to perform the matrix multiply.
tensorflow::Status EmitCallToRuntime();
@@ -77,6 +82,38 @@ class DotOpEmitter {
// no padding, and a rank of two.
bool ShapesAreLegalForRuntimeDot() const;
+ // Represents the dimensions of a matrix-matrix multiply operation.
+ struct MatMultDims {
+ // The number of rows in the LHS.
+ int64 m;
+
+ // The number of columns in the LHS, which is also must be equal to the
+ // number of rows in the RHS.
+ int64 k;
+
+ // The number of columns on the RHS.
+ int64 n;
+
+ // True if the LHS matrix column major.
+ bool lhs_column_major;
+
+ // True if the RHS matrix column major.
+ bool rhs_column_major;
+ };
+
+ // Get the MatMultDims instance for the dot product this DotOpEmitter
+ // represents. Precondition: the dot is of rank 2 (and thus its operands are
+ // of rank 2 as well).
+ MatMultDims GetMatMultDims() const;
+
+ // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector
+ // registers.
+ int64 GetGemvTilingFactor() const {
+ const int64 kDefaultTilingFactor = 8;
+ return options::LlvmIrGemvTilingFactor(hlo_module_config_)
+ .value_or(kDefaultTilingFactor);
+ }
+
const HloInstruction& dot_;
const bool transpose_lhs_;
const bool transpose_rhs_;
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
index b99b36a55e..7149a19310 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
@@ -105,7 +105,9 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) {
return false;
}
- if (ProfitableToImplementDotInLlvmIr(hlo) == DotInLlvmIrProfitable::kYes) {
+ if (ProfitableToImplementDotInUntiledLlvmIr(hlo) ==
+ DotInLlvmIrProfitable::kYes ||
+ ProfitableToImplementDotInTiledLlvmIr(hlo)) {
return false;
}
@@ -136,7 +138,7 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) {
return false;
}
-DotInLlvmIrProfitable ProfitableToImplementDotInLlvmIr(
+DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr(
const HloInstruction& dot) {
if (dot.opcode() == HloOpcode::kDot && dot.shape().dimensions_size() == 2) {
const Shape& result_shape = dot.shape();
@@ -178,5 +180,16 @@ DotInLlvmIrProfitable ProfitableToImplementDotInLlvmIr(
return DotInLlvmIrProfitable::kNo;
}
+bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) {
+ // Any Matrix-Vector product of floating point or integral type, or
+ // a transpose-dot fusion of the same can be lowered to a tiled LLVM
+ // IR implementation.
+ const Shape& shape = dot.shape();
+ return shape.dimensions_size() == 2 &&
+ (shape.dimensions(0) == 1 || shape.dimensions(1) == 1) &&
+ (primitive_util::IsFloatingPointType(shape.element_type()) ||
+ primitive_util::IsIntegralType(shape.element_type()));
+}
+
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h
index 66656ed997..cbe07a7c2b 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h
@@ -29,16 +29,21 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& dot);
enum class DotInLlvmIrProfitable { kYes, kNo, kWithColumnMajorRhs };
// Returns a value to indicate if (and under what conditions) will lowering
-// |dot| as a pure LLVM IR dot operation be profitable over calling into Eigen.
-// Possible return values are:
+// |dot| as a untiled LLVM IR dot operation be profitable over calling into
+// Eigen or emitting a tiled LLVM IR implementation. Possible return values
+// are:
//
// * DotInLlvmIrProfitable::kYes - always profitable.
// * DotInLlvmIrProfitable::kNo - never profitable.
// * DotInLlvmIrProfitable::kWithColumnMajorRhs - only if we can manage to make
// the Rhs layout column major.
-DotInLlvmIrProfitable ProfitableToImplementDotInLlvmIr(
+DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr(
const HloInstruction& dot);
+// Returns true to indicate that we can generate a tiled LLVM IR implementation
+// for |dot|.
+bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot);
+
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index a20ce6826c..e547f291b8 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -1983,6 +1983,11 @@ Status IrEmitter::HandleSend(HloInstruction* send) {
return Unimplemented("Send is not implemented on CPU. See b/33942983.");
}
+Status IrEmitter::HandleSendDone(HloInstruction* send_done) {
+ // TODO(b/33942983): Support Send/Recv on CPU.
+ return Unimplemented("Send-done is not implemented on CPU. See b/33942983.");
+}
+
Status IrEmitter::HandleSlice(HloInstruction* slice) {
VLOG(2) << "HandleSlice: " << slice->ToString();
auto operand = slice->operand(0);
@@ -2148,6 +2153,11 @@ Status IrEmitter::HandleRecv(HloInstruction* recv) {
return Unimplemented("Recv is not implemented on CPU. See b/33942983.");
}
+Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) {
+ // TODO(b/33942983): Support Send/Recv on CPU.
+ return Unimplemented("Recv-done is not implemented on CPU. See b/33942983.");
+}
+
Status IrEmitter::HandlePad(HloInstruction* pad) {
// CPU backend does not properly handle negative padding but this is ok
// because negative padding should be removed by the algebraic simplifier.
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 5d061e11e3..83eded5ad8 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -171,11 +171,13 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleReduceWindow(HloInstruction* reduce_window) override;
Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override;
Status HandleSend(HloInstruction* send) override;
+ Status HandleSendDone(HloInstruction* send_done) override;
Status HandleSlice(HloInstruction* slice) override;
Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
Status HandleDynamicUpdateSlice(
HloInstruction* dynamic_update_slice) override;
Status HandleRecv(HloInstruction* recv) override;
+ Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandlePad(HloInstruction* pad) override;
Status HandleTuple(HloInstruction* tuple) override;
Status HandleMap(HloInstruction* map) override;
diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/layout_assignment.cc
index c446b6b792..b75ca34e0a 100644
--- a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/layout_assignment.cc
@@ -51,7 +51,7 @@ Status CpuLayoutAssignment::AddBackendConstraints(
tensorflow::gtl::FlatMap<const HloInstruction*, bool>
should_make_rhs_col_major_cache;
auto should_make_rhs_col_major = [&](const HloInstruction& instruction) {
- if (ProfitableToImplementDotInLlvmIr(instruction) !=
+ if (ProfitableToImplementDotInUntiledLlvmIr(instruction) !=
DotInLlvmIrProfitable::kWithColumnMajorRhs) {
return false;
}
@@ -68,7 +68,7 @@ Status CpuLayoutAssignment::AddBackendConstraints(
bool result = std::all_of(
rhs->users().begin(), rhs->users().end(), [&](HloInstruction* user) {
- return ProfitableToImplementDotInLlvmIr(*user) ==
+ return ProfitableToImplementDotInUntiledLlvmIr(*user) ==
DotInLlvmIrProfitable::kWithColumnMajorRhs &&
user->operand(0) != rhs;
});
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index de3cd15440..bc73839a88 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -211,9 +211,11 @@ class DfsHloVisitorBase {
virtual Status HandlePad(HloInstructionPtr hlo) = 0;
- virtual Status HandleSend(HloInstructionPtr hlo) = 0;
+ virtual Status HandleSend(HloInstructionPtr send) = 0;
+ virtual Status HandleSendDone(HloInstructionPtr send_done) = 0;
- virtual Status HandleRecv(HloInstructionPtr hlo) = 0;
+ virtual Status HandleRecv(HloInstructionPtr recv) = 0;
+ virtual Status HandleRecvDone(HloInstructionPtr recv_done) = 0;
virtual Status HandleBatchNormTraining(HloInstructionPtr hlo) = 0;
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index 7ce88be89d..5415bab5b3 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -167,11 +167,17 @@ class DfsHloVisitorWithDefaultBase
Status HandleWhile(HloInstructionPtr xla_while) override {
return DefaultAction(xla_while);
}
+ Status HandleRecv(HloInstructionPtr recv) override {
+ return DefaultAction(recv);
+ }
+ Status HandleRecvDone(HloInstructionPtr recv_done) override {
+ return DefaultAction(recv_done);
+ }
Status HandleSend(HloInstructionPtr send) override {
return DefaultAction(send);
}
- Status HandleRecv(HloInstructionPtr recv) override {
- return DefaultAction(recv);
+ Status HandleSendDone(HloInstructionPtr send_done) override {
+ return DefaultAction(send_done);
}
// Invoked to inform the visitor that the traversal has completed, and that
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index 536b96dcf6..e79d0a4c79 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -279,6 +280,13 @@ std::vector<AlgorithmDesc> ConvolutionThunk::GetAlgorithms(
return algorithms;
}
+static string AlgorithmToString(const se::dnn::AlgorithmDesc& algo) {
+ if (algo.tensor_ops_enabled()) {
+ return tensorflow::strings::StrCat(algo.algo_id(), "+TC");
+ }
+ return tensorflow::strings::StrCat(algo.algo_id());
+}
+
tensorflow::Status ConvolutionThunk::ConvolveWithTune(
const BatchDescriptor& input_descriptor, se::DeviceMemory<float> input_data,
const FilterDescriptor& filter_descriptor,
@@ -303,6 +311,8 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune(
buffer_allocations.device_ordinal(),
buffer_allocations.memory_allocator());
se::dnn::ProfileResult profile_result;
+ VLOG(3) << "Trying algorithm " << AlgorithmToString(algorithm)
+ << " for ConvolutionThunk: " << this;
bool launch_ok =
Convolve(input_descriptor, input_data, filter_descriptor, filter_data,
output_descriptor, output_data, convolution_descriptor,
@@ -310,6 +320,11 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune(
&scratch_allocator, &profile_result)
.ok();
if (launch_ok && profile_result.is_valid()) {
+ VLOG(3) << "Run of algorithm " << AlgorithmToString(algorithm)
+ << " for ConvolutionThunk " << this << " succeeded, taking "
+ << profile_result.elapsed_time_in_ms()
+ << "ms. (Best result: " << best_result.elapsed_time_in_ms()
+ << "ms)";
if (profile_result.elapsed_time_in_ms() <
best_result.elapsed_time_in_ms()) {
best_result = profile_result;
@@ -319,6 +334,9 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune(
best_result_without_scratch.elapsed_time_in_ms()) {
best_result_without_scratch = profile_result;
}
+ } else {
+ VLOG(3) << "Run of algorithm " << AlgorithmToString(algorithm)
+ << " for ConvolutionThunk " << this << " failed.";
}
}
@@ -343,8 +361,8 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune(
{
VLOG(2) << "Using convolution algorithm ("
- << best_algorithm_.algorithm().algo_id() << ", "
- << best_algorithm_.algorithm_no_scratch().algo_id()
+ << AlgorithmToString(best_algorithm_.algorithm()) << ", "
+ << AlgorithmToString(best_algorithm_.algorithm_no_scratch())
<< ") for ConvolutionThunk: " << this;
ConvolveScratchAllocator scratch_allocator(
buffer_allocations.device_ordinal(),
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index ceb0e530c1..b77f75ff79 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -75,6 +75,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/subprocess.h"
+#include "tensorflow/core/platform/tracing.h"
namespace se = ::perftools::gputools;
@@ -87,6 +88,7 @@ namespace gpu {
namespace {
+using tensorflow::port::Tracing;
using tensorflow::strings::StrCat;
// Any address of a variable residing in global memory or returned by one of the
@@ -231,6 +233,7 @@ tensorflow::Status PrepareHloModuleForIrEmitting(
// code (i.e. a cubin) as a byte array.
StatusOr<std::vector<uint8>> CompilePtx(const string& ptx, int cc_major,
int cc_minor) {
+ Tracing::TraceMe annotation("Compile PTX", /*is_expensive=*/true);
const string ptxas_path =
tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin", "ptxas");
VLOG(2) << "Using ptxas at " << ptxas_path;
@@ -295,11 +298,15 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec) {
TF_RET_CHECK(stream_exec != nullptr);
- TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(),
- stream_exec->GetDeviceDescription(),
- ShapeSizeBytesFunction()));
- TF_RETURN_IF_ERROR(
- PrepareHloModuleForIrEmitting(module.get(), ShapeSizeBytesFunction()));
+ {
+ Tracing::TraceMe annotation("HLO Transforms", module->name(),
+ /*is_expensive=*/true);
+ TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(),
+ stream_exec->GetDeviceDescription(),
+ ShapeSizeBytesFunction()));
+ TF_RETURN_IF_ERROR(
+ PrepareHloModuleForIrEmitting(module.get(), ShapeSizeBytesFunction()));
+ }
llvm::LLVMContext llvm_context;
std::string buffer;
@@ -421,6 +428,22 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
VLOG(2) << "PTX:";
XLA_VLOG_LINES(2, ptx);
+ // Write PTX to IR dump directory, if IR dumping was requested.
+ if (!ir_dump_directory.empty()) {
+ const string ptx_outfile = tensorflow::io::JoinPath(
+ ir_dump_directory, StrCat(module->name(), ".ptx"));
+ auto status = [&] {
+ auto* env = tensorflow::Env::Default();
+ TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(ir_dump_directory));
+ TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(env, ptx_outfile, ptx));
+ return Status::OK();
+ }();
+ if (!status.ok()) {
+ LOG(WARNING) << "Couldn't dump PTX for module " << module->name()
+ << " to " << ptx_outfile << ": " << status;
+ }
+ }
+
const std::vector<uint8> cubin =
CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor);
@@ -444,6 +467,7 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
std::vector<uint8> GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx,
int cc_major,
int cc_minor) {
+ Tracing::TraceMe annotation("PTX->CUBIN", /*is_expensive=*/true);
bool inserted;
decltype(compilation_cache_.begin()) iter;
// Pointers into compilation_cache_ where the ptx and (optional) cubin are
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 57a3f713e3..9d55c7859d 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -128,10 +128,18 @@ Status IrEmitter::HandleSend(HloInstruction*) {
return Unimplemented("Send is not implemented on GPU");
}
+Status IrEmitter::HandleSendDone(HloInstruction*) {
+ return Unimplemented("Send-Done is not implemented on GPU");
+}
+
Status IrEmitter::HandleRecv(HloInstruction*) {
return Unimplemented("Recv is not implemented on GPU");
}
+Status IrEmitter::HandleRecvDone(HloInstruction*) {
+ return Unimplemented("Recv-done is not implemented on GPU");
+}
+
Status IrEmitter::HandleTuple(HloInstruction* tuple) {
std::vector<llvm::Value*> base_ptrs;
for (const HloInstruction* operand : tuple->operands()) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index 263992d925..61fdeaa0ee 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -84,7 +84,9 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleOutfeed(HloInstruction* outfeed) override;
Status HandleSort(HloInstruction* sort) override;
Status HandleSend(HloInstruction* send) override;
+ Status HandleSendDone(HloInstruction* send_done) override;
Status HandleRecv(HloInstruction* recv) override;
+ Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandleParameter(HloInstruction* parameter) override;
Status HandleReduce(HloInstruction* reduce) override;
Status HandleTuple(HloInstruction* tuple) override;
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
index 817e95a31c..1cb963be61 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
@@ -60,6 +60,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/tracing.h"
namespace xla {
namespace gpu {
@@ -488,6 +489,9 @@ StatusOr<string> CompileToPtx(llvm::Module* module,
string ptx;
{
+ tensorflow::port::Tracing::TraceMe annotation(
+ "Compiling IR", llvm_ir::AsString(module->getName()),
+ /*is_expensive=*/true);
ScopedLoggingTimer compilation_timer(
"Compile module " + llvm_ir::AsString(module->getName()),
/*vlog_level=*/2);
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 17ba2b673a..1877065f67 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -337,10 +337,18 @@ Status HloCostAnalysis::HandleSend(const HloInstruction*) {
return Status::OK();
}
+Status HloCostAnalysis::HandleSendDone(const HloInstruction*) {
+ return Status::OK();
+}
+
Status HloCostAnalysis::HandleRecv(const HloInstruction*) {
return Status::OK();
}
+Status HloCostAnalysis::HandleRecvDone(const HloInstruction*) {
+ return Status::OK();
+}
+
Status HloCostAnalysis::HandleReshape(const HloInstruction*) {
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index 8074868e37..0f44775378 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -60,7 +60,9 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleReducePrecision(const HloInstruction* hlo) override;
Status HandleConcatenate(const HloInstruction* concatenate) override;
Status HandleSend(const HloInstruction* send) override;
+ Status HandleSendDone(const HloInstruction* send_done) override;
Status HandleRecv(const HloInstruction* recv) override;
+ Status HandleRecvDone(const HloInstruction* recv_done) override;
Status HandleConvert(const HloInstruction* convert) override;
Status HandleCopy(const HloInstruction* copy) override;
Status HandleDot(const HloInstruction* dot) override;
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index 7c4626e78a..3601a790c4 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -79,12 +79,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
// Test that two identical constants with different layouts are commoned if
// the pass is not layout sensitive.
auto builder = HloComputation::Builder(TestName());
- auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
- test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
- /*minor_to_major=*/{0, 1})));
- auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant(
- test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
- /*minor_to_major=*/{1, 0})));
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
constant1->shape(), HloOpcode::kAdd, constant1, constant2));
@@ -111,12 +111,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
// Test that two identical constants with different layouts are *not* commoned
// if the pass is layout sensitive.
auto builder = HloComputation::Builder(TestName());
- auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
- test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
- /*minor_to_major=*/{0, 1})));
- auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant(
- test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
- /*minor_to_major=*/{1, 0})));
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
constant1->shape(), HloOpcode::kAdd, constant1, constant2));
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index 92261bce62..ff80f18bb5 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -242,6 +242,51 @@ bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
return false;
}
+bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
+ CHECK_EQ(send->opcode(), HloOpcode::kSend);
+ bool changed = false;
+ // Send forwards the operand value to the output tuple at {0}.
+ for (auto& pair : GetInstructionValueSet(send->operand(0))) {
+ const ShapeIndex& operand_index = pair.first;
+ const HloValueSet& operand_value_set = pair.second;
+
+ ShapeIndex index = {0};
+ for (int64 i : operand_index) {
+ index.push_back(i);
+ }
+
+ HloValueSet& value_set = GetValueSet(send, index);
+ if (value_set != operand_value_set) {
+ value_set = operand_value_set;
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) {
+ CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone);
+ bool changed = false;
+ // RecvDone forwards the operand value at {0} to the output.
+ for (auto& pair : GetInstructionValueSet(recv_done)) {
+ ShapeIndex& index = pair.first;
+ HloValueSet& value_set = pair.second;
+
+ ShapeIndex operand_index = {0};
+ for (int64 i : index) {
+ operand_index.push_back(i);
+ }
+
+ const HloValueSet& operand_value_set =
+ GetValueSet(recv_done->operand(0), operand_index);
+ if (value_set != operand_value_set) {
+ value_set = operand_value_set;
+ changed = true;
+ }
+ }
+ return changed;
+}
+
bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) {
CHECK_EQ(call->opcode(), HloOpcode::kCall);
InstructionValueSet& value_set = GetInstructionValueSet(call);
@@ -429,6 +474,10 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
return UpdateCallValueSet(instruction);
case HloOpcode::kWhile:
return UpdateWhileValueSet(instruction);
+ case HloOpcode::kSend:
+ return UpdateSendValueSet(instruction);
+ case HloOpcode::kRecvDone:
+ return UpdateRecvDoneValueSet(instruction);
default:
// Instruction does not forward HloValues (it defines all values in its
// output). No update is necessary.
@@ -537,6 +586,12 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
GetValueSet(instruction, /*index=*/{}).AddValue(value);
};
+ // Lambda to set the value set at the given index of the output.
+ auto define_value_at = [this, &instruction](const ShapeIndex& index) {
+ HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
+ GetValueSet(instruction, index).AddValue(value);
+ };
+
switch (instruction->opcode()) {
case HloOpcode::kBitcast:
if (bitcast_defines_value_) {
@@ -577,6 +632,16 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
// values flow from their operands.
define_top_level_only();
break;
+ case HloOpcode::kRecvDone:
+ // RecvDone aliases its input tuple element {0}, therefore does not
+ // define any values.
+ break;
+ case HloOpcode::kSend:
+ // Send produces a tuple of {aliased operand, U32 context}, therefore
+ // only defines the top-level tuple and the tuple element at {1}.
+ define_value_at(/*index=*/{});
+ define_value_at(/*index=*/{1});
+ break;
default:
define_all_values();
break;
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
index 207e553bf7..63467f3206 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
@@ -146,7 +146,9 @@ class HloDataflowAnalysis {
bool UpdateCopyValueSet(HloInstruction* copy);
bool UpdateGetTupleElementValueSet(HloInstruction* gte);
bool UpdateParameterValueSet(HloInstruction* parameter);
+ bool UpdateRecvDoneValueSet(HloInstruction* recv_done);
bool UpdateSelectValueSet(HloInstruction* select);
+ bool UpdateSendValueSet(HloInstruction* send);
bool UpdateTupleValueSet(HloInstruction* tuple);
bool UpdateWhileValueSet(HloInstruction* xla_while);
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 4b8eb237a6..66a538fc51 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -1139,6 +1139,54 @@ TEST_P(HloDataflowAnalysisTest, TupleCopy) {
analysis.GetValueDefinedAt(copy, /*index=*/{}).live_out_of_module());
}
+TEST_P(HloDataflowAnalysisTest, SendAndSendDone) {
+ // Test that a Send forwards its operand to the output tuple at {0}.
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
+ auto send = builder.AddInstruction(
+ HloInstruction::CreateSend(param, /*channel_id=*/0));
+ auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
+ module_->AddEntryComputation(builder.Build());
+
+ bool ssa_form = GetParam();
+ const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
+
+ EXPECT_EQ(analysis.values().size(), 4);
+
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(param));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{}));
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(send, /*index=*/{0}));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{1}));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done));
+ EXPECT_THAT(HloValuesAt(send, /*index=*/{0}),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(param)));
+}
+
+TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) {
+ // Test that a RecvDone forwards its operand tuple element at {0} to the
+ // output.
+ auto builder = HloComputation::Builder(TestName());
+ auto recv = builder.AddInstruction(
+ HloInstruction::CreateRecv(scalar_shape_, /*channel_id=*/0));
+ auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
+ module_->AddEntryComputation(builder.Build());
+
+ bool ssa_form = GetParam();
+ const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
+
+ EXPECT_EQ(analysis.values().size(), 3);
+
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{}));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0}));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{1}));
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done));
+ EXPECT_THAT(HloValuesAt(recv_done),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(recv, {0})));
+ EXPECT_TRUE(
+ analysis.GetValueDefinedAt(recv, /*index=*/{0}).live_out_of_module());
+}
+
TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) {
// A simple chain of elementwise operations. No values should interfere.
//
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 88b77ccdd0..a722d1b3d9 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -1450,6 +1450,10 @@ HloEvaluator::HloEvaluator() {
typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this);
typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this);
typed_visitors_[C64] = MakeUnique<TypedVisitor<complex64>>(this);
+
+ typed_visitors_[BF16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
+ return Unimplemented("HloEvaluator: unhandled primitive type: BF16.");
+ });
typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE.");
});
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 67b6e215fc..7557aaa248 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -39,16 +39,18 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
HloEvaluator();
// Evaluates an HLO module and an array of pointers to literals.
// Returns the evaluated result as a literal if successful.
- // Precondition: argument literals correspond to each input computation's
- // parameters in their post-ordering. See comment below for example.
+ // Precondition: The indices of arg_literals correspond to the parameter
+ // numbers of the HLO parameters in the computation. See comment below for an
+ // example.
StatusOr<std::unique_ptr<Literal>> Evaluate(
const HloModule& module,
tensorflow::gtl::ArraySlice<const Literal*> arg_literals);
// Evaluates an HLO computation and an array of pointers to literals.
// Returns the evaluated result as a literal if successful.
- // Precondition: argument literals correspond to the input computation's
- // parameters in their post-ordering. For e.g., consider the following graph:
+ // Precondition: The indices of arg_literals correspond to the parameter
+ // numbers of the HLO parameters in the computation. For e.g., consider the
+ // following graph:
//
// *
// / \
@@ -57,8 +59,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// / \
// Parameter0 Constant
//
- // The input literals array will have its first literal map to Parameter0 and
- // the second map to Parameter1.
+ // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number
+ // 1 in this computation. The input literals array will then have its first
+ // literal map to Parameter0 and the second map to Parameter1.
StatusOr<std::unique_ptr<Literal>> Evaluate(
const HloComputation& computation,
tensorflow::gtl::ArraySlice<const Literal*> arg_literals);
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index fd162622ce..04b3059fb1 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -761,12 +761,22 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
string HloDotDumper::GetInstructionNodeInlinedOperands(
const HloInstruction* instr) {
auto stringify_constant = [](const HloInstruction* constant) {
- if (ShapeUtil::IsEffectiveScalar(constant->shape())) {
- auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex(
- constant->shape(), /*linear_index=*/0);
- return Printf("%s (%s)", constant->literal().GetAsString(elem_idx),
+ const auto& shape = constant->shape();
+
+ // Print the literal value of constants with <= K elements.
+ optional<int64> elem_count;
+ if (!ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)) {
+ elem_count = 1;
+ for (int64 dim : shape.dimensions()) {
+ *elem_count *= dim;
+ }
+ }
+ if (elem_count.has_value() && *elem_count <= 8) {
+ return Printf("%s (%s)", constant->literal().ToString(),
ShapeUtil::HumanString(constant->shape()));
}
+
+ // Otherwise, print e.g. "%constant.42 (s32[100])".
string constant_name;
if (tensorflow::StringPiece(constant->name()).starts_with("%constant")) {
constant_name = constant->name();
@@ -933,7 +943,9 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kFusion:
return kGray;
case HloOpcode::kSend:
+ case HloOpcode::kSendDone:
case HloOpcode::kRecv:
+ case HloOpcode::kRecvDone:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kCrossReplicaSum:
@@ -1027,7 +1039,9 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
? ""
: StrCat("stride=", VectorString(instr->slice_strides()));
case HloOpcode::kSend:
+ case HloOpcode::kSendDone:
case HloOpcode::kRecv:
+ case HloOpcode::kRecvDone:
return StrCat("channel_id=", instr->channel_id());
default:
return "";
@@ -1289,7 +1303,9 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) {
auto is_displayed = [&](const HloInstruction* instr) {
// Constants are displayed inline with their users; they're never omitted.
- return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant;
+ // Nodes in subcomputations are always shown.
+ return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant ||
+ instr->parent() != root->parent();
};
// Make a second pass over 'nodes' to fix up the NodeFilterResults now that we
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 5107ac782d..1e83c69b50 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -371,20 +371,50 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
HloInstruction* operand, int64 channel_id) {
+ // Send instruction produces a tuple of {aliased operand, U32 context}.
+ Shape output_shape = ShapeUtil::MakeTupleShape(
+ {operand->shape(), ShapeUtil::MakeShape(U32, {})});
auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kSend, ShapeUtil::MakeNil()));
+ WrapUnique(new HloInstruction(HloOpcode::kSend, output_shape));
instruction->AppendOperand(operand);
instruction->channel_id_ = channel_id;
return instruction;
}
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
+ HloInstruction* operand) {
+ CHECK(operand->opcode() == HloOpcode::kSend)
+ << "SendDone must take the context operand from Send";
+ auto instruction = WrapUnique(
+ new HloInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil()));
+ instruction->AppendOperand(operand);
+ instruction->channel_id_ = operand->channel_id();
+ return instruction;
+}
+
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
const Shape& shape, int64 channel_id) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kRecv, shape));
+ // Recv instruction produces a tuple of {receive buffer, U32 context}.
+ Shape output_shape =
+ ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})});
+ auto instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kRecv, output_shape));
instruction->channel_id_ = channel_id;
return instruction;
}
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
+ HloInstruction* operand) {
+ CHECK(operand->opcode() == HloOpcode::kRecv)
+ << "RecvDone must take the context operand from Recv";
+ Shape output_shape = ShapeUtil::GetTupleElementShape(operand->shape(), 0);
+ auto instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kRecvDone, output_shape));
+ instruction->AppendOperand(operand);
+ instruction->channel_id_ = operand->channel_id();
+ return instruction;
+}
+
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions) {
@@ -908,7 +938,9 @@ RandomDistribution HloInstruction::random_distribution() const {
bool HloInstruction::HasSideEffect() const {
switch (opcode_) {
case HloOpcode::kSend:
+ case HloOpcode::kSendDone:
case HloOpcode::kRecv:
+ case HloOpcode::kRecvDone:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kTrace:
@@ -1164,7 +1196,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
new_operands[4], epsilon(), feature_index());
break;
case HloOpcode::kRecv:
+ case HloOpcode::kRecvDone:
case HloOpcode::kSend:
+ case HloOpcode::kSendDone:
case HloOpcode::kTrace:
LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_);
}
@@ -1557,8 +1591,10 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kSort:
- case HloOpcode::kSend:
case HloOpcode::kRecv:
+ case HloOpcode::kRecvDone:
+ case HloOpcode::kSend:
+ case HloOpcode::kSendDone:
return false;
}
}
@@ -1790,7 +1826,7 @@ string HloInstruction::ToString(bool compact_operands, bool include_metadata,
if (include_metadata &&
(!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
!metadata_.source_file().empty())) {
- StrAppend(&result, " # metadata=", metadata_.ShortDebugString());
+ StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}");
}
return result;
}
@@ -1850,12 +1886,13 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}"));
}
if (window_ != nullptr) {
- extra.push_back(window_util::ToString(*window_));
+ extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
}
if (padding_config_ != nullptr) {
- extra.push_back(StrCat("padding=", padding_config_->ShortDebugString()));
+ extra.push_back(
+ StrCat("padding=", xla::PaddingConfigToString(*padding_config_)));
}
- if (!slice_starts_.empty() && !slice_limits_.empty()) {
+ if (opcode() == HloOpcode::kSlice) {
std::vector<string> bounds;
bounds.reserve(slice_starts_.size());
const bool omit_stride =
@@ -1868,6 +1905,16 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
}
extra.push_back(StrCat("slice={", Join(bounds, ", "), "}"));
}
+ if (opcode() == HloOpcode::kDynamicSlice) {
+ extra.push_back(
+ StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}"));
+ }
+ if (opcode() == HloOpcode::kBatchNormTraining ||
+ opcode() == HloOpcode::kBatchNormInference ||
+ opcode() == HloOpcode::kBatchNormGrad) {
+ extra.push_back(StrCat("epsilon=", epsilon()));
+ extra.push_back(StrCat("feature_index=", feature_index()));
+ }
if (convolution_dimension_numbers_ != nullptr) {
extra.push_back(ConvolutionDimensionNumbersToString());
@@ -1891,7 +1938,8 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
})));
}
- if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv) {
+ if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv ||
+ opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) {
extra.push_back(StrCat("channel_id=", channel_id_));
}
@@ -2071,8 +2119,10 @@ bool HloInstruction::IsFusable() const {
case HloOpcode::kOutfeed:
case HloOpcode::kParameter:
case HloOpcode::kTrace:
- case HloOpcode::kSend:
case HloOpcode::kRecv:
+ case HloOpcode::kRecvDone:
+ case HloOpcode::kSend:
+ case HloOpcode::kSendDone:
return false;
// Only fuse Rng if it is used once, otherwise the random numbers generated
// will be different in each fusion. If it is the root (user count = 0)
@@ -2279,10 +2329,14 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleCall(this);
case HloOpcode::kCustomCall:
return visitor->HandleCustomCall(this);
- case HloOpcode::kSend:
- return visitor->HandleSend(this);
case HloOpcode::kRecv:
return visitor->HandleRecv(this);
+ case HloOpcode::kRecvDone:
+ return visitor->HandleRecvDone(this);
+ case HloOpcode::kSend:
+ return visitor->HandleSend(this);
+ case HloOpcode::kSendDone:
+ return visitor->HandleSendDone(this);
// These opcodes are not handled here.
case HloOpcode::kTrace:
@@ -2841,6 +2895,40 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
return InvalidArgument("Unknown fusion kind: %s", kind_name.c_str());
}
+string PaddingConfigToString(const PaddingConfig& padding) {
+ bool has_interior_padding =
+ std::any_of(padding.dimensions().begin(), padding.dimensions().end(),
+ [](const PaddingConfig::PaddingConfigDimension& dim) {
+ return dim.interior_padding() != 0;
+ });
+ return Join(
+ padding.dimensions(), "x",
+ [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) {
+ StrAppend(
+ out, dim.edge_padding_low(), "_", dim.edge_padding_high(),
+ has_interior_padding ? StrCat("_", dim.interior_padding()) : "");
+ });
+}
+
+string OpMetadataToString(const OpMetadata& metadata) {
+ std::vector<string> result;
+ using tensorflow::str_util::CEscape;
+ if (!metadata.op_type().empty()) {
+ result.push_back(StrCat("op_type=\"", CEscape(metadata.op_type()), "\""));
+ }
+ if (!metadata.op_name().empty()) {
+ result.push_back(StrCat("op_name=\"", CEscape(metadata.op_name()), "\""));
+ }
+ if (!metadata.source_file().empty()) {
+ result.push_back(
+ StrCat("source_file=\"", CEscape(metadata.source_file()), "\""));
+ }
+ if (metadata.source_line() != 0) {
+ result.push_back(StrCat("source_line=", metadata.source_line()));
+ }
+ return Join(result, " ");
+}
+
std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
return os << ToString(kind);
}
@@ -2856,13 +2944,7 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const {
const auto append_dims = [&](const std::vector<string>& dims,
const Shape& shape) {
CHECK_EQ(dims.size(), ShapeUtil::Rank(shape));
- for (int64 logical = 0; logical < dims.size(); ++logical) {
- int64 physical = logical;
- if (!shape.layout().minor_to_major().empty()) {
- physical = LayoutUtil::Major(shape.layout(), logical);
- }
- result += dims[physical];
- }
+ StrAppend(&result, Join(dims, ""));
};
// lhs_dims[i] is the symbol of the logical dimension i for the lhs
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 5ff04a4888..05befe7806 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -181,18 +181,28 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand,
tensorflow::StringPiece outfeed_config);
- // Creates a send instruction with the given channel id, which sends the
- // operand data to a unique receive instruction in another computation that
- // has the same channel id.
+ // Creates an asynchronous send instruction with the given channel id, which
+ // initiates sending the operand data to a unique receive instruction in
+ // another computation that has the same channel id.
static std::unique_ptr<HloInstruction> CreateSend(HloInstruction* operand,
int64 channel_id);
- // Creates a receive instruction with the given channel id, which receives
- // data of the given shape from a unique send instruction in another
- // computation that has the same channel id.
+ // Blocks until data transfer for the Send instruction (operand) is complete.
+ // The operand must be kSend.
+ static std::unique_ptr<HloInstruction> CreateSendDone(
+ HloInstruction* operand);
+
+ // Creates an asynchronous receive instruction with the given channel id,
+ // which allocates resources to receive data of the given shape from a unique
+ // send instruction in another computation that has the same channel id.
static std::unique_ptr<HloInstruction> CreateRecv(const Shape& shape,
int64 channel_id);
+ // Blocks until data transfer for the Recv instruction (operand) is complete
+ // and returns the receive buffer. The operand must be kRecv.
+ static std::unique_ptr<HloInstruction> CreateRecvDone(
+ HloInstruction* operand);
+
// Creates a slice instruction, where the operand is sliced by the given
// start/limit indices.
static std::unique_ptr<HloInstruction> CreateSlice(
@@ -853,6 +863,11 @@ class HloInstruction {
return *window_;
}
+ // Sets the window data in a windowed operation such as convolution.
+ void set_window(const Window& window) {
+ window_ = MakeUnique<Window>(window);
+ }
+
// Returns the padding configuration for a pad node.
//
// Precondition: opcode() == HloOpcode::kPad
@@ -1224,6 +1239,10 @@ string ToString(HloInstruction::FusionKind kind);
StatusOr<HloInstruction::FusionKind> StringToFusionKind(
const string& kind_name);
+// Custom stringification functions for protos that live inside HloInstruction.
+string PaddingConfigToString(const PaddingConfig& padding);
+string OpMetadataToString(const OpMetadata& metadata);
+
std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
// Map classes that guarantee a deterministic iteration order when the key is
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index 4d4010b025..268fa0f632 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -121,6 +121,7 @@ HLO_MATCHER(Outfeed);
HLO_MATCHER(Pad);
HLO_MATCHER(Power);
HLO_MATCHER(Recv);
+HLO_MATCHER(RecvDone);
HLO_MATCHER(Reduce);
HLO_MATCHER(ReducePrecision);
HLO_MATCHER(ReduceWindow);
@@ -131,6 +132,7 @@ HLO_MATCHER(Rng);
HLO_MATCHER(Select);
HLO_MATCHER(SelectAndScatter);
HLO_MATCHER(Send);
+HLO_MATCHER(SendDone);
HLO_MATCHER(ShiftLeft);
HLO_MATCHER(ShiftRightLogical);
HLO_MATCHER(ShiftRightArithmetic);
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 6469851791..5141e7bc8d 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -85,7 +85,11 @@ class HloModule {
std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const;
// Return a pointer to the entry computation of the module..
- HloComputation* entry_computation() const {
+ const HloComputation* entry_computation() const {
+ CHECK_NE(nullptr, entry_computation_);
+ return entry_computation_;
+ }
+ HloComputation* entry_computation() {
CHECK_NE(nullptr, entry_computation_);
return entry_computation_;
}
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc
index 8974deb530..822e2f1f53 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_config.cc
@@ -39,8 +39,8 @@ void HloModuleConfig::SetDefaultComputationLayout(
}
string HloModuleConfig::compilation_cache_key() const {
- string key = tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled_,
- "::hybrid=", has_hybrid_result_);
+ string key =
+ tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled_);
StrAppend(&key, "::(");
std::vector<string> params;
for (const ShapeLayout& param_layout :
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h
index 4a7ead9c10..a5ee895e48 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.h
+++ b/tensorflow/compiler/xla/service/hlo_module_config.h
@@ -104,16 +104,6 @@ class HloModuleConfig {
// Whether to enable HLO-level profiling.
bool hlo_profiling_enabled_ = false;
- // If this flag is true, the generated executable will return a ShapedBuffer
- // holding the result of the computation. In a ShapedBuffer, tuples have their
- // structure held in host memory and the element arrays (leaves of the tuple
- // structure) stored in device memory. The ShapedBuffer is considered "hybrid"
- // because its leaves are on device but its structure is stored on
- // host. Otherwise, if this flag is false, the generated executable will
- // return a DeviceMemoryBase where the result is held entirely in device
- // memory.
- bool has_hybrid_result_ = false;
-
// Module/graph-level seed handle.
uint64 seed_ = 0;
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index d68fc20321..e0d02e0665 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -97,6 +97,7 @@ namespace xla {
V(kPower, "power") \
V(kReal, "real") \
V(kRecv, "recv") \
+ V(kRecvDone, "recv-done") \
V(kReduce, "reduce") \
V(kReducePrecision, "reduce-precision") \
V(kReduceWindow, "reduce-window") \
@@ -108,6 +109,7 @@ namespace xla {
V(kSelect, "select") \
V(kSelectAndScatter, "select-and-scatter") \
V(kSend, "send") \
+ V(kSendDone, "send-done") \
V(kShiftLeft, "shift-left") \
V(kShiftRightArithmetic, "shift-right-arithmetic") \
V(kShiftRightLogical, "shift-right-logical") \
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index c96df50e79..828be8490c 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -66,7 +66,9 @@ bool IsRematerializable(const HloInstruction* instruction) {
case HloOpcode::kInfeed:
case HloOpcode::kParameter:
case HloOpcode::kRecv:
+ case HloOpcode::kRecvDone:
case HloOpcode::kSend:
+ case HloOpcode::kSendDone:
case HloOpcode::kTrace:
case HloOpcode::kWhile:
return false;
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index f463e57d99..158fb9a546 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#define EIGEN_USE_THREADS
#include "tensorflow/compiler/xla/service/hlo_runner.h"
@@ -19,8 +20,6 @@ limitations under the License.
#include <string>
#include <utility>
-#define EIGEN_USE_THREADS
-
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 0d019d22f5..7356663454 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
@@ -38,6 +39,15 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) {
}
string HloSharding::ToString() const {
+ if (IsTuple()) {
+ std::vector<string> parts;
+ parts.reserve(tuple_elements_.size());
+ for (const HloSharding& element : tuple_elements_) {
+ parts.push_back(element.ToString());
+ }
+ return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}");
+ }
+
string result = StrCat("{", (replicated_ ? " replicated" : ""),
(maximal_ ? " maximal" : ""));
@@ -53,6 +63,11 @@ string HloSharding::ToString() const {
}
bool HloSharding::UsesDevice(int64 device) const {
+ if (IsTuple()) {
+ return std::any_of(
+ tuple_elements_.begin(), tuple_elements_.end(),
+ [&](const HloSharding& s) { return s.UsesDevice(device); });
+ }
const auto& devices = tile_assignment_;
return replicated_ ||
std::find(devices.begin(), devices.end(), device) != devices.end();
@@ -61,6 +76,7 @@ bool HloSharding::UsesDevice(int64 device) const {
std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
CHECK(!ShapeUtil::IsTuple(tile_shape_));
CHECK(!maximal_);
+ CHECK(!IsTuple());
std::vector<int64> ret_index;
tile_assignment_.Each([&](tensorflow::gtl::ArraySlice<int64> index, int64 d) {
if (d == device) {
@@ -74,6 +90,7 @@ std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
int64 HloSharding::DeviceForTileIndex(
tensorflow::gtl::ArraySlice<int64> index) const {
CHECK(!replicated_);
+ CHECK(!IsTuple());
if (maximal_) {
return *tile_assignment_.begin();
}
@@ -82,7 +99,7 @@ int64 HloSharding::DeviceForTileIndex(
}
std::vector<int64> HloSharding::TileOffsetForDevice(int64 device) const {
- CHECK(!ShapeUtil::IsTuple(tile_shape_));
+ CHECK(!IsTuple());
std::vector<int64> index = TileIndexForDevice(device);
if (maximal_) {
@@ -97,7 +114,7 @@ std::vector<int64> HloSharding::TileOffsetForDevice(int64 device) const {
}
std::vector<int64> HloSharding::TileLimitForDevice(int64 device) const {
- CHECK(!ShapeUtil::IsTuple(tile_shape_));
+ CHECK(!IsTuple());
CHECK(!maximal_); // Maximal shardings do not have a valid tile shape.
std::vector<int64> index = TileIndexForDevice(device);
@@ -108,13 +125,41 @@ std::vector<int64> HloSharding::TileLimitForDevice(int64 device) const {
}
StatusOr<int64> HloSharding::UniqueDevice() const {
- if (!replicated_ && maximal_) {
+ if (IsTuple()) {
+ if (tuple_elements_.empty()) {
+ return tensorflow::errors::InvalidArgument(
+ "UniqueDevice() called on empty tuple");
+ }
+ std::vector<StatusOr<int64>> results;
+ std::transform(tuple_elements_.begin(), tuple_elements_.end(),
+ std::back_inserter(results),
+ [](const HloSharding& s) { return s.UniqueDevice(); });
+ if (std::all_of(results.begin(), results.end(),
+ [&](const StatusOr<int64>& s) {
+ return s.ok() && results[0].ok() &&
+ s.ValueOrDie() == results[0].ValueOrDie();
+ })) {
+ return results[0];
+ } else {
+ return tensorflow::errors::InvalidArgument(
+ "Tuple did not contain a unique device");
+ }
+ }
+ if (!replicated_ && maximal_ && !IsTuple()) {
return static_cast<int64>(*tile_assignment_.begin());
}
return tensorflow::errors::InvalidArgument(
"UniqueDevice() called on sharding that executes on multiple devices");
}
+bool HloSharding::HasUniqueDevice() const {
+ if (IsTuple()) {
+ return UniqueDevice().status().ok();
+ } else {
+ return !IsReplicated() && IsTileMaximal();
+ }
+}
+
Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
if (replicated_) {
return Status::OK();
@@ -193,9 +238,19 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
/*static*/ StatusOr<HloSharding> HloSharding::FromProto(
const OpSharding& proto) {
- if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
+ if (proto.type() == OpSharding::Type::OpSharding_Type_TUPLE) {
+ std::vector<HloSharding> tuple_shardings;
+ tuple_shardings.reserve(proto.tuple_shardings().size());
+ for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) {
+ TF_ASSIGN_OR_RETURN(HloSharding sharding,
+ HloSharding::FromProto(tuple_sharding_proto));
+ tuple_shardings.push_back(sharding);
+ }
+ return HloSharding(tuple_shardings);
+ } else if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
return Replicate();
- } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL) {
+ } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL ||
+ proto.tile_assignment_devices().size() == 1) {
return HloSharding(proto.tile_assignment_devices(0));
}
// Some versions of gcc cannot infer the TileAssignment constructor from a
@@ -212,6 +267,15 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
OpSharding HloSharding::ToProto() const {
OpSharding result;
+
+ if (IsTuple()) {
+ for (const HloSharding& element : tuple_elements_) {
+ *result.add_tuple_shardings() = element.ToProto();
+ }
+ result.set_type(OpSharding::Type::OpSharding_Type_TUPLE);
+ return result;
+ }
+
*result.mutable_tile_shape() = tile_shape_;
for (int64 dim : tile_assignment_.dimensions()) {
result.add_tile_assignment_dimensions(dim);
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index d7ada30c70..dbd16b7c9d 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/hash/hash.h"
@@ -67,6 +68,18 @@ class HloSharding {
// `num_tiles` tiles.
static HloSharding Tile1D(const Shape& input_shape, int64 num_tiles);
+ // Creates a new sharding for a tuple type. The given ShapeTree must have
+ // elements for every leaf shape contained in the tuple.
+ static HloSharding Tuple(const ShapeTree<HloSharding>& sub_shardings) {
+ std::vector<HloSharding> flattened_list;
+ flattened_list.reserve(
+ std::distance(sub_shardings.leaf_begin(), sub_shardings.leaf_end()));
+ for (const auto& index_to_sharding : sub_shardings.leaves()) {
+ flattened_list.push_back(index_to_sharding.second);
+ }
+ return HloSharding(flattened_list);
+ }
+
// Create a new sharding from a protobuf OpSharding.
static StatusOr<HloSharding> FromProto(const OpSharding& proto);
@@ -76,47 +89,93 @@ class HloSharding {
// Validate that this sharding can be applied to a tensor with shape `shape`.
Status Validate(const Shape& shape, int64 num_devices) const;
+ // Returns true if the sharding has tuple type.
+ bool IsTuple() const { return tuple_; }
+
// Returns true if the sharding is trivial: replicate on all devices.
- bool IsReplicated() const { return replicated_; }
+ bool IsReplicated() const {
+ if (!IsTuple()) {
+ return replicated_;
+ }
+ return std::all_of(tuple_elements_.begin(), tuple_elements_.end(),
+ [](const HloSharding& s) { return s.IsReplicated(); });
+ }
// Returns true if the tile size is the same as the input size.
- bool IsTileMaximal() const { return maximal_; }
+ bool IsTileMaximal() const {
+ if (!IsTuple()) {
+ return maximal_;
+ }
+ return std::all_of(tuple_elements_.begin(), tuple_elements_.end(),
+ [](const HloSharding& s) { return s.IsTileMaximal(); });
+ }
// Returns true if the sharding defines an operation on the given device.
bool UsesDevice(int64 device) const;
// Returns the tile that should be executed on the given device.
+ // REQUIRES: !IsTuple()
std::vector<int64> TileIndexForDevice(int64 device) const;
// Returns the device that should execute the given tile.
// It is an error to call this if is_replicated() is true.
+ // REQUIRES: !IsTuple()
int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice<int64> index) const;
// Given a device ID, returns the offset within the input space of the
// tile that should be executed on the given core. This returns the lower
// extent of the tile in the input space.
+ // REQUIRES: !IsTuple()
std::vector<int64> TileOffsetForDevice(int64 device) const;
// Given a device ID, returns the limit within the input space of the
// tile that should be executed on the given core. This returns the upper
// extent of the tile in the input space.
+ // REQUIRES: !IsTuple()
std::vector<int64> TileLimitForDevice(int64 device) const;
// Returns the single device this op operates on.
- // Requires !Replicated() && IsTileMaximal().
+ // REQUIRES: !IsTuple&& !Replicated() && IsTileMaximal()
StatusOr<int64> UniqueDevice() const;
// Returns true if this op only uses a single device.
- bool HasUniqueDevice() const { return !IsReplicated() && IsTileMaximal(); }
+ bool HasUniqueDevice() const;
+
+ // Returns the ShapeTree containing the shardings for each element of this
+ // tuple, if IsTuple, or a ShapeTree with a single element containing this
+ // sharding. Only the leaf elements are populated. This creates a new
+ // ShapeTree object so is not cheap.
+ ShapeTree<HloSharding> GetAsShapeTree(const Shape& shape) const {
+ if (IsTuple()) {
+ ShapeTree<HloSharding> result(shape, HloSharding::Replicate());
+ CHECK_EQ(std::distance(result.leaf_begin(), result.leaf_end()),
+ tuple_elements_.size());
+ auto it = tuple_elements_.begin();
+ for (auto& index_to_sharding : result.leaves()) {
+ index_to_sharding.second = *it++;
+ }
+ return result;
+ } else {
+ return ShapeTree<HloSharding>(shape, *this);
+ }
+ }
bool operator==(const HloSharding& other) const {
return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
protobuf_util::ProtobufEquals(tile_shape_, other.tile_shape_) &&
- tile_assignment_ == other.tile_assignment_;
+ tile_assignment_ == other.tile_assignment_ &&
+ tuple_elements_ == other.tuple_elements_;
}
bool operator!=(const HloSharding& other) const { return !(*this == other); }
size_t Hash() const {
+ if (!tuple_) {
+ size_t h = 0;
+ for (const auto& element : tuple_elements_) {
+ h = tensorflow::Hash64Combine(h, element.Hash());
+ }
+ return h;
+ }
if (replicated_) {
return 0;
}
@@ -131,33 +190,47 @@ class HloSharding {
}
// Gets the tile shape.
- // It is an error to call this if IsTileMaximal() is true.
+ // REQUIRES: !IsTileMaximal() && !IsTuple()
const Shape& tile_shape() const { return tile_shape_; }
// Gets the tile assignment tensor.
- // It is an error to call this if IsReplicated() is true.
+ // REQUIRES: !IsReplicated() && !IsTuple()
const Array<int64>& tile_assignment() const { return tile_assignment_; }
private:
HloSharding()
: replicated_(true),
maximal_(true),
+ tuple_(false),
tile_shape_(),
tile_assignment_({0}) {}
explicit HloSharding(int64 device_id)
: replicated_(false),
maximal_(true),
+ tuple_(false),
tile_shape_(),
tile_assignment_({1}, device_id) {}
HloSharding(const Shape& tile_shape, const Array<int64>& tile_assignment)
: replicated_(false),
maximal_(false),
+ tuple_(false),
tile_shape_(tile_shape),
tile_assignment_(tile_assignment) {}
+ HloSharding(const std::vector<HloSharding>& tuple_shardings)
+ : replicated_(false),
+ maximal_(false),
+ tuple_(true),
+ tile_assignment_({0}),
+ tuple_elements_(tuple_shardings) {}
bool replicated_;
bool maximal_;
+ bool tuple_;
Shape tile_shape_;
Array<int64> tile_assignment_;
+ // Only non-empty when tuple_ is true, but because empty tuples are allowed
+ // may also be empty even then. This is a flattened list of all the leaf
+ // shardings in a tuple shape, by pre-order walk (ShapeTree iterator order).
+ std::vector<HloSharding> tuple_elements_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
index d0a20471a0..3161dda271 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
@@ -70,6 +70,11 @@ TEST_F(HloShardingTest, DevicePlacement) {
/*num_devices=*/6));
EXPECT_IS_NOT_OK(
sharding.Validate(ShapeUtil::MakeShape(U32, {4}), /*num_devices=*/5));
+
+ ShapeTree<HloSharding> shape_tree =
+ sharding.GetAsShapeTree(ShapeUtil::MakeShape(U32, {4}));
+ EXPECT_EQ(shape_tree.element({}), sharding);
+ EXPECT_TRUE(shape_tree.IsLeaf({}));
}
TEST_F(HloShardingTest, Tile) {
@@ -132,6 +137,29 @@ TEST_F(HloShardingTest, Tile) {
}
}
+TEST_F(HloShardingTest, NestedTuple) {
+ // nested_tuple_shape = (f32[], (f32[3]), f32[4, 6])
+ Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({
+ ShapeUtil::MakeShape(F32, {}),
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3})}),
+ ShapeUtil::MakeShape(F32, {4, 6}),
+ });
+
+ OpSharding proto;
+ proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE);
+ *proto.add_tuple_shardings() = HloSharding::Replicate().ToProto();
+ *proto.add_tuple_shardings() = HloSharding::AssignDevice(0).ToProto();
+ *proto.add_tuple_shardings() = HloSharding::AssignDevice(1).ToProto();
+ HloSharding tuple_sharding =
+ HloSharding::FromProto(proto).ConsumeValueOrDie();
+
+ ShapeTree<HloSharding> shape_tree =
+ tuple_sharding.GetAsShapeTree(nested_tuple_shape);
+ EXPECT_EQ(shape_tree.element({0}), HloSharding::Replicate());
+ EXPECT_EQ(shape_tree.element({1, 0}), HloSharding::AssignDevice(0));
+ EXPECT_EQ(shape_tree.element({2}), HloSharding::AssignDevice(1));
+}
+
TEST_F(HloShardingTest, Hash) {
auto hash_compare_equal = [](const HloSharding& a, const HloSharding& b) {
if (a.Hash() != b.Hash()) {
@@ -184,6 +212,51 @@ TEST_F(HloShardingTest, Hash) {
MakeArray({2, 2}, {0, 3, 1, 2}));
EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
}
+
+ HloSharding default_sharding = HloSharding::Replicate();
+ {
+ ShapeTree<HloSharding> shape_tree(ShapeUtil::MakeTupleShape({}),
+ default_sharding);
+ HloSharding sharding1 = HloSharding::Replicate();
+ HloSharding sharding2 = HloSharding::Tuple(shape_tree);
+ EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
+ }
+
+ {
+ ShapeTree<HloSharding> shape_tree(ShapeUtil::MakeTupleShape({}),
+ default_sharding);
+ HloSharding sharding1 = HloSharding::Tuple(shape_tree);
+ HloSharding sharding2 = HloSharding::Tuple(shape_tree);
+ EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
+ }
+
+ {
+ ShapeTree<HloSharding> shape_tree1(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
+ default_sharding);
+ *shape_tree1.mutable_element({0}) = HloSharding::Replicate();
+ ShapeTree<HloSharding> shape_tree2(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
+ default_sharding);
+ *shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0);
+ HloSharding sharding1 = HloSharding::Tuple(shape_tree1);
+ HloSharding sharding2 = HloSharding::Tuple(shape_tree2);
+ EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
+ }
+
+ {
+ ShapeTree<HloSharding> shape_tree1(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
+ default_sharding);
+ *shape_tree1.mutable_element({0}) = HloSharding::AssignDevice(0);
+ ShapeTree<HloSharding> shape_tree2(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
+ default_sharding);
+ *shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0);
+ HloSharding sharding1 = HloSharding::Tuple(shape_tree1);
+ HloSharding sharding2 = HloSharding::Tuple(shape_tree2);
+ EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
+ }
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index c1aa655401..c938450891 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -270,12 +270,40 @@ class ShapeVerifier : public DfsHloVisitor {
pad->padding_config()));
}
- Status HandleSend(HloInstruction*) override {
- return tensorflow::Status::OK();
+ Status HandleSend(HloInstruction* send) override {
+ TF_RET_CHECK(send->users().size() == 1);
+ const HloInstruction* send_done = send->users()[0];
+ TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
+ TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
+ return CheckShape(
+ send, ShapeUtil::MakeTupleShape(
+ {send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})}));
}
- Status HandleRecv(HloInstruction*) override {
- return tensorflow::Status::OK();
+ Status HandleSendDone(HloInstruction* send_done) override {
+ TF_RET_CHECK(send_done->operands().size() == 1);
+ const HloInstruction* send = send_done->operand(0);
+ TF_RET_CHECK(send->opcode() == HloOpcode::kSend);
+ TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
+ return CheckShape(send_done, ShapeUtil::MakeNil());
+ }
+
+ Status HandleRecv(HloInstruction* recv) override {
+ TF_RET_CHECK(recv->users().size() == 1);
+ const HloInstruction* recv_done = recv->users()[0];
+ TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
+ TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
+ return CheckShape(recv,
+ ShapeUtil::MakeTupleShape(
+ {recv_done->shape(), ShapeUtil::MakeShape(U32, {})}));
+ }
+
+ Status HandleRecvDone(HloInstruction* recv_done) override {
+ TF_RET_CHECK(recv_done->operands().size() == 1);
+ const HloInstruction* recv = recv_done->operand(0);
+ TF_RET_CHECK(recv->opcode() == HloOpcode::kRecv);
+ TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
+ return CheckShape(recv_done, recv->shape().tuple_shapes(0));
}
Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override {
@@ -365,6 +393,19 @@ class ShapeVerifier : public DfsHloVisitor {
instruction->opcode(), instruction->operands()));
}
+ // Checks if the given two instructions shares the same channel id.
+ Status CheckSameChannel(const HloInstruction* instr1,
+ const HloInstruction* instr2) {
+ if (instr1->channel_id() != instr2->channel_id()) {
+ return FailedPrecondition(
+ "Expected to have the same channel id, actual channel ids are: %s "
+ "(%lld), %s (%lld)",
+ instr1->ToString().c_str(), instr1->channel_id(),
+ instr2->ToString().c_str(), instr2->channel_id());
+ }
+ return tensorflow::Status::OK();
+ }
+
// Returns the size of a Shape in bytes.
const std::function<int64(const Shape&)> shape_size_fn_;
};
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 0d1b7bc109..dea47b1fd7 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -113,7 +113,9 @@ namespace xla {
case HloOpcode::kTrace:
case HloOpcode::kWhile:
case HloOpcode::kSend:
+ case HloOpcode::kSendDone:
case HloOpcode::kRecv:
+ case HloOpcode::kRecvDone:
return true;
}
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index 86dee8462f..96f937caf9 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -89,7 +89,7 @@ StatusOr<se::DeviceMemoryBase> InterpreterExecutable::ExecuteOnStream(
uint64 start_micros = tensorflow::Env::Default()->NowMicros();
- HloComputation* computation = module().entry_computation();
+ const HloComputation* computation = module().entry_computation();
if (computation->num_parameters() != arguments.size()) {
return tensorflow::errors::Internal(
"Mismatch between argument count and graph parameter count.");
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index c39ff52230..d51c0d1dfb 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -131,10 +131,10 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
std::vector<std::initializer_list<int64>> minor_to_majors = {{0, 1}, {1, 0}};
for (auto& minor_to_major : minor_to_majors) {
auto builder = HloComputation::Builder(TestName());
- auto constant_literal1 = test_utils::CreateR2LiteralWithLayout<float>(
- {{1.0, 2.0}, {3.0, 4.0}}, minor_to_major);
- auto constant_literal2 = test_utils::CreateR2LiteralWithLayout<float>(
- {{5.0, 6.0}, {7.0, 8.0}}, minor_to_major);
+ auto constant_literal1 = Literal::CreateR2WithLayout<float>(
+ {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
+ auto constant_literal2 = Literal::CreateR2WithLayout<float>(
+ {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
Shape ashape = constant_literal1->shape();
auto constant1 = builder.AddInstruction(
@@ -181,12 +181,12 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
// Verify the layouts of a tuple are assigned properly (the element layouts
// match their source).
auto builder = HloComputation::Builder(TestName());
- auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant(
- test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
- {0, 1})));
- auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
- test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
- {1, 0})));
+ auto constant0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant0, constant1}));
@@ -218,12 +218,12 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
TEST_F(LayoutAssignmentTest, TupleSelect) {
// Verify layouts of a select with tuple operands is assigned properly.
auto builder = HloComputation::Builder(TestName());
- auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant(
- test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
- {0, 1})));
- auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
- test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
- {1, 0})));
+ auto constant0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto tuple0 = builder.AddInstruction(
HloInstruction::CreateTuple({constant0, constant1}));
auto tuple1 = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc
index c27a8956a7..53d88eda7a 100644
--- a/tensorflow/compiler/xla/service/liveness_util.cc
+++ b/tensorflow/compiler/xla/service/liveness_util.cc
@@ -215,7 +215,8 @@ bool CanShareOperandBufferWithUser(
auto add_operand_it =
std::find_if(add->operands().begin(), add->operands().end(),
[&](HloInstruction* operand) {
- return operand->opcode() == HloOpcode::kDot ||
+ return operand->opcode() == HloOpcode::kConvolution ||
+ operand->opcode() == HloOpcode::kDot ||
(operand->opcode() == HloOpcode::kFusion &&
operand->fusion_kind() ==
HloInstruction::FusionKind::kTransposeDot);
@@ -294,7 +295,8 @@ bool CanShareOperandBufferWithUser(HloInstruction* operand,
auto add_operand_it =
std::find_if(add->operands().begin(), add->operands().end(),
[&](HloInstruction* operand) {
- return operand->opcode() == HloOpcode::kDot ||
+ return operand->opcode() == HloOpcode::kConvolution ||
+ operand->opcode() == HloOpcode::kDot ||
(operand->opcode() == HloOpcode::kFusion &&
operand->fusion_kind() ==
HloInstruction::FusionKind::kTransposeDot);
diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD
index 075d4a1ab5..8f24bb1718 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/BUILD
+++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD
@@ -155,6 +155,30 @@ cc_library(
],
)
+cc_library(
+ name = "vector_support_library",
+ srcs = ["vector_support_library.cc"],
+ hdrs = ["vector_support_library.h"],
+ deps = [
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "@llvm//:core",
+ ],
+)
+
+cc_library(
+ name = "kernel_support_library",
+ srcs = ["kernel_support_library.cc"],
+ hdrs = ["kernel_support_library.h"],
+ deps = [
+ ":llvm_loop",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/core:lib",
+ "@llvm//:core",
+ ],
+)
+
# -----------------------------------------------------------------------------
filegroup(
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
new file mode 100644
index 0000000000..29cc0f81bd
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
@@ -0,0 +1,65 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
+
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
+
+namespace xla {
+void KernelSupportLibrary::For(
+ tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ llvm::Value* step,
+ const std::function<void(llvm::Value*, bool)>& for_body_generator) {
+ If(ir_builder_->CreateICmpSLT(start, end), [&]() {
+ for_body_generator(start, /*is_first_iteration=*/true);
+ For(name, ir_builder_->CreateAdd(start, step), end, step,
+ [&](llvm::Value* iv) { for_body_generator(iv, false); });
+ });
+}
+
+void KernelSupportLibrary::For(
+ tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ llvm::Value* step, bool peel_first_iteration,
+ const std::function<void(llvm::Value*, llvm::Value*)>& for_body_generator) {
+ if (peel_first_iteration) {
+ For(name, start, end, step, true,
+ [&](llvm::Value* indvar, bool is_first_iteration) {
+ for_body_generator(indvar, ir_builder_->getInt1(is_first_iteration));
+ });
+ } else {
+ std::unique_ptr<llvm_ir::ForLoop> loop = llvm_ir::ForLoop::EmitForLoop(
+ name, start, end, step, ir_builder_,
+ /*prevent_unrolling=*/prevent_unrolling_,
+ /*prevent_vectorization=*/prevent_vectorization_);
+ ir_builder_->SetInsertPoint(&loop->GetBodyBasicBlock()->back());
+ for_body_generator(loop->GetIndVarValue(),
+ /*is_first_iteration=*/ir_builder_->CreateICmpEQ(
+ loop->GetIndVarValue(), start));
+ llvm_ir::SetToLastInsertPoint(loop->GetExitBasicBlock(), ir_builder_);
+ }
+}
+
+void KernelSupportLibrary::If(
+ llvm::Value* condition, const std::function<void()>& true_block_generator,
+ const std::function<void()>& false_block_generator) {
+ llvm_ir::LlvmIfData if_data =
+ llvm_ir::EmitIfThenElse(condition, "", ir_builder_);
+ ir_builder_->SetInsertPoint(&if_data.true_block->back());
+ true_block_generator();
+ ir_builder_->SetInsertPoint(&if_data.false_block->back());
+ false_block_generator();
+ llvm_ir::SetToLastInsertPoint(if_data.after_block, ir_builder_);
+}
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
new file mode 100644
index 0000000000..9bafb7b577
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
@@ -0,0 +1,128 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_
+#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_
+
+#include <string>
+
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace xla {
+// A thin wrapper around llvm_loop.h to make code generating structured control
+// flow more readable.
+class KernelSupportLibrary {
+ public:
+ // `ir_builder` is the llvm::IRBuilder instance used to generate LLVM IR.
+ // If `prevent_unrolling` is true then unrolling is explicitly disabled on
+ // every loop generated by this instance of KernelSupportLibrary.
+ explicit KernelSupportLibrary(llvm::IRBuilder<>* ir_builder,
+ bool prevent_unrolling = true,
+ bool prevent_vectorization = true)
+ : ir_builder_(ir_builder),
+ prevent_unrolling_(prevent_unrolling),
+ prevent_vectorization_(prevent_vectorization) {}
+
+ // Generates the following control flow structure:
+ //
+ // if (`start` < `end`) {
+ // `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/true)`;
+ // for (i64 i = `start` + `step`; i s< `end`; i += `step`)
+ // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`;
+ // }
+ void For(
+ tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ llvm::Value* step,
+ const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
+ for_body_generator);
+
+ void For(
+ tensorflow::StringPiece name, int64 start, int64 end, int64 step,
+ const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
+ for_body_generator) {
+ For(name, /*start=*/ir_builder_->getInt64(start),
+ /*end=*/ir_builder_->getInt64(end),
+ /*step=*/ir_builder_->getInt64(step), for_body_generator);
+ }
+
+ // Generates the following control flow structure if `peel_first_iteration` is
+ // true:
+ //
+ // if (`start` < `end`) {
+ // `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/,true)`;
+ // for (i64 i = `start` + `step`; i s< `end`; i += `step`)
+ // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/,false)`;
+ // }
+ //
+ // and the following if `peel_first_iteration` is false:
+ //
+ // for (i64 i = `start`; i s< `end`; i += `step`)
+ // `for_body_generator(/*ind_var=*/,i,
+ // /*is_first_iteration=*/,(i != `start`))`;
+ void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ llvm::Value* step, bool peel_first_iteration,
+ const std::function<void(llvm::Value* ind_var,
+ llvm::Value* is_first_iteration)>&
+ for_body_generator);
+
+ void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ int64 step, bool peel_first_iteration,
+ const std::function<void(llvm::Value* ind_var,
+ llvm::Value* is_first_iteration)>&
+ for_body_generator) {
+ For(name, /*start=*/start, /*end=*/end,
+ /*step=*/ir_builder_->getInt64(step), peel_first_iteration,
+ for_body_generator);
+ }
+
+ void For(
+ tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ llvm::Value* step,
+ const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
+ For(name, start, end, step,
+ /*peel_first_iteration=*/false,
+ [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); });
+ }
+
+ void For(
+ tensorflow::StringPiece name, int64 start, int64 end, int64 step,
+ const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
+ For(name, /*start=*/ir_builder_->getInt64(start),
+ /*end=*/ir_builder_->getInt64(end),
+ /*step=*/ir_builder_->getInt64(step), for_body_generator);
+ }
+
+ // Generates the following control flow structure:
+ //
+ // if (`condition`)
+ // `true_block_generator()`;
+ // else
+ // `false_block_generator()`;
+ void If(llvm::Value* condition,
+ const std::function<void()>& true_block_generator,
+ const std::function<void()>& false_block_generator = []() {});
+
+ private:
+ llvm::IRBuilder<>* ir_builder_;
+ bool prevent_unrolling_;
+ bool prevent_vectorization_;
+};
+} // namespace xla
+
+#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
index 83d35cb9ef..7b227ce294 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
@@ -34,21 +34,24 @@ namespace llvm_ir {
ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix,
llvm::Value* start_index, llvm::Value* end_index,
- llvm::Value* step, bool prevent_unrolling)
+ llvm::Value* step, bool prevent_unrolling,
+ bool prevent_vectorization)
: prefix_(prefix.ToString()),
suffix_(suffix.ToString()),
start_index_(start_index),
end_index_(end_index),
step_(step),
insert_before_bb_(nullptr),
- prevent_unrolling_(prevent_unrolling) {}
+ prevent_unrolling_(prevent_unrolling),
+ prevent_vectorization_(prevent_vectorization) {}
/* static */ std::unique_ptr<ForLoop> ForLoop::EmitForLoop(
tensorflow::StringPiece prefix, llvm::Value* start_index,
llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder,
- bool prevent_unrolling) {
- std::unique_ptr<ForLoop> loop(new ForLoop(
- prefix, /*suffix=*/"", start_index, end_index, step, prevent_unrolling));
+ bool prevent_unrolling, bool prevent_vectorization) {
+ std::unique_ptr<ForLoop> loop(new ForLoop(prefix, /*suffix=*/"", start_index,
+ end_index, step, prevent_unrolling,
+ prevent_vectorization));
loop->Emit(ir_builder);
return loop;
}
@@ -127,14 +130,12 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) {
ir_builder->CreateStore(indvar_inc, indvar_address);
llvm::BranchInst* back_branch = ir_builder->CreateBr(header_bb_);
- if (prevent_unrolling_) {
- const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable";
- llvm::LLVMContext* ctx = &back_branch->getContext();
-
+ std::vector<llvm::Metadata*> loop_metadata = GetLoopMetadata(ir_builder);
+ if (!loop_metadata.empty()) {
+ llvm::LLVMContext* ctx = &start_index_->getContext();
auto temp_node = llvm::MDNode::getTemporary(*ctx, llvm::None);
- auto no_unroll_node = llvm::MDNode::get(
- *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)});
- auto loop_id = llvm::MDNode::get(*ctx, {temp_node.get(), no_unroll_node});
+ loop_metadata.insert(loop_metadata.begin(), temp_node.get());
+ auto loop_id = llvm::MDNode::get(*ctx, loop_metadata);
loop_id->replaceOperandWith(0, loop_id);
back_branch->setMetadata(llvm::LLVMContext::MD_loop, loop_id);
}
@@ -143,6 +144,27 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) {
ir_builder->SetInsertPoint(exit_bb_);
}
+std::vector<llvm::Metadata*> ForLoop::GetLoopMetadata(
+ llvm::IRBuilder<>* ir_builder) {
+ const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable";
+ const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable";
+ llvm::LLVMContext* ctx = &start_index_->getContext();
+
+ std::vector<llvm::Metadata*> result;
+ if (prevent_unrolling_) {
+ result.push_back(llvm::MDNode::get(
+ *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)}));
+ }
+
+ if (prevent_vectorization_) {
+ result.push_back(llvm::MDNode::get(
+ *ctx, {llvm::MDString::get(*ctx, kLlvmLoopVectorizeMDName),
+ llvm::ConstantAsMetadata::get(ir_builder->getFalse())}));
+ }
+
+ return result;
+}
+
string ForLoop::GetQualifiedName(tensorflow::StringPiece name) {
return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_));
}
@@ -156,23 +178,25 @@ llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name,
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
llvm::Value* start_index,
llvm::Value* end_index,
- bool prevent_unrolling) {
+ bool prevent_unrolling,
+ bool prevent_vectorization) {
return AddLoop(suffix, start_index, end_index, ir_builder_->getInt64(1),
- prevent_unrolling);
+ prevent_unrolling, prevent_vectorization);
}
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
llvm::Value* start_index,
llvm::Value* end_index,
llvm::Value* stride,
- bool prevent_unrolling) {
+ bool prevent_unrolling,
+ bool prevent_vectorization) {
if (inner_loop_body_bb_ != nullptr) {
// Create this loop inside the previous one.
ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt());
}
std::unique_ptr<ForLoop> loop(new ForLoop(
/*prefix=*/name_, suffix, start_index, end_index, stride,
- prevent_unrolling));
+ prevent_unrolling, prevent_vectorization));
loop->Emit(ir_builder_);
if (outer_loop_preheader_bb_ == nullptr) {
@@ -191,20 +215,24 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
int64 end_index,
tensorflow::StringPiece suffix,
- bool prevent_unrolling) {
+ bool prevent_unrolling,
+ bool prevent_vectorization) {
CHECK_LE(start_index, end_index);
return AddLoop(suffix, ir_builder_->getInt64(start_index),
- ir_builder_->getInt64(end_index), prevent_unrolling);
+ ir_builder_->getInt64(end_index), prevent_unrolling,
+ prevent_vectorization);
}
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
int64 end_index, int64 stride,
tensorflow::StringPiece suffix,
- bool prevent_unrolling) {
+ bool prevent_unrolling,
+ bool prevent_vectorization) {
CHECK_LE(start_index, end_index);
return AddLoop(suffix, ir_builder_->getInt64(start_index),
ir_builder_->getInt64(end_index),
- ir_builder_->getInt64(stride), prevent_unrolling);
+ ir_builder_->getInt64(stride), prevent_unrolling,
+ prevent_vectorization);
}
IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape,
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
index 90f7c7df9e..20069ce5a2 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
@@ -71,12 +71,10 @@ class ForLoop {
//
// If `prevent_unrolling` is true then emit metadata that directs LLVM to not
// unroll the generated loop.
- static std::unique_ptr<ForLoop> EmitForLoop(tensorflow::StringPiece prefix,
- llvm::Value* start_index,
- llvm::Value* end_index,
- llvm::Value* step,
- llvm::IRBuilder<>* ir_builder,
- bool prevent_unrolling = false);
+ static std::unique_ptr<ForLoop> EmitForLoop(
+ tensorflow::StringPiece prefix, llvm::Value* start_index,
+ llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder,
+ bool prevent_unrolling = false, bool prevent_vectorization = false);
// The names of the blocks follow LLVM's conventions. Control flow amongst the
// blocks for the example C code looks like:
@@ -130,7 +128,7 @@ class ForLoop {
ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix,
llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step,
- bool prevent_unrolling);
+ bool prevent_unrolling, bool prevent_vectorization);
// Emit the loop at the insert point of the builder.
void Emit(llvm::IRBuilder<>* ir_builder);
@@ -142,6 +140,10 @@ class ForLoop {
// they are set.
string GetQualifiedName(tensorflow::StringPiece name);
+ // Return a list of metadata nodes that should be associated with the
+ // llvm::Loop for this `ForLoop`.
+ std::vector<llvm::Metadata*> GetLoopMetadata(llvm::IRBuilder<>* ir_builder);
+
string prefix_;
string suffix_;
llvm::Value* start_index_;
@@ -160,6 +162,7 @@ class ForLoop {
llvm::BasicBlock* exit_bb_;
llvm::Value* indvar_;
bool prevent_unrolling_;
+ bool prevent_vectorization_;
TF_DISALLOW_COPY_AND_ASSIGN(ForLoop);
};
@@ -185,24 +188,28 @@ class ForLoopNest {
std::unique_ptr<ForLoop> AddLoop(tensorflow::StringPiece suffix,
llvm::Value* start_index,
llvm::Value* end_index, llvm::Value* stride,
- bool prevent_unrolling = false);
+ bool prevent_unrolling = false,
+ bool prevent_vectorization = false);
// Like the above, except that it defaults to a stride of one.
std::unique_ptr<ForLoop> AddLoop(tensorflow::StringPiece suffix,
llvm::Value* start_index,
llvm::Value* end_index,
- bool prevent_unrolling = false);
+ bool prevent_unrolling = false,
+ bool prevent_vectorization = false);
// A convenient wrapper of the other flavor of AddLoop. The given start and
// end index are constant.
std::unique_ptr<ForLoop> AddLoop(int64 start_index, int64 end_index,
int64 stride, tensorflow::StringPiece suffix,
- bool prevent_unrolling = false);
+ bool prevent_unrolling = false,
+ bool prevent_vectorization = false);
// Like the above, except that it defaults to a stride of one.
std::unique_ptr<ForLoop> AddLoop(int64 start_index, int64 end_index,
tensorflow::StringPiece suffix,
- bool prevent_unrolling = false);
+ bool prevent_unrolling = false,
+ bool prevent_vectorization = false);
// Add loops to iterate through the indices within the specified
// shape. The returned index collects the induction variables of the
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index 956c0d5f05..d95409e399 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -537,6 +537,14 @@ void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
builder->SetInsertPoint(blk, blk->getFirstInsertionPt());
}
+void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
+ if (llvm::Instruction* terminator = blk->getTerminator()) {
+ builder->SetInsertPoint(terminator);
+ } else {
+ builder->SetInsertPoint(blk);
+ }
+}
+
llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
llvm::IRBuilder<>* builder) {
auto size = rotand->getType()->getPrimitiveSizeInBits();
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
index 304192b58e..f70d9f88b3 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
@@ -243,6 +243,8 @@ llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper,
void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder);
+void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder);
+
// Create a bitwise rotation of `rotand` by `rotor`.
llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
llvm::IRBuilder<>* builder);
diff --git a/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc
new file mode 100644
index 0000000000..e8c6a83618
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc
@@ -0,0 +1,150 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h"
+
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+
+namespace xla {
+VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type,
+ int64 vector_size,
+ llvm::IRBuilder<>* ir_builder,
+ std::string name)
+ : vector_size_(vector_size),
+ primitive_type_(primitive_type),
+ ir_builder_(ir_builder),
+ name_(std::move(name)) {
+ scalar_type_ = llvm_ir::PrimitiveTypeToIrType(
+ primitive_type, ir_builder_->GetInsertBlock()->getModule());
+ scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_);
+ vector_type_ = llvm::VectorType::get(scalar_type_, vector_size);
+ vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_);
+}
+
+llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) {
+ if (scalar_type_->isFloatingPointTy()) {
+ return ir_builder()->CreateFMul(lhs, rhs, name());
+ } else {
+ return ir_builder()->CreateMul(lhs, rhs, name());
+ }
+}
+
+llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) {
+ if (scalar_type_->isFloatingPointTy()) {
+ return ir_builder()->CreateFAdd(lhs, rhs, name());
+ } else {
+ return ir_builder()->CreateAdd(lhs, rhs, name());
+ }
+}
+
+llvm::Value* VectorSupportLibrary::ComputeOffsetPointer(
+ llvm::Value* base_pointer, llvm::Value* offset_elements) {
+ if (base_pointer->getType() != scalar_pointer_type()) {
+ base_pointer = ir_builder()->CreateBitCast(base_pointer,
+ scalar_pointer_type(), name());
+ }
+ return ir_builder()->CreateInBoundsGEP(base_pointer, {offset_elements},
+ name());
+}
+
+llvm::Value* VectorSupportLibrary::LoadVector(llvm::Value* pointer) {
+ if (pointer->getType() != vector_pointer_type()) {
+ pointer =
+ ir_builder()->CreateBitCast(pointer, vector_pointer_type(), name());
+ }
+ return ir_builder()->CreateAlignedLoad(
+ pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name());
+}
+
+llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) {
+ if (pointer->getType() != scalar_pointer_type()) {
+ pointer =
+ ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name());
+ }
+ return ir_builder()->CreateAlignedLoad(
+ pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name());
+}
+
+void VectorSupportLibrary::StoreVector(llvm::Value* value,
+ llvm::Value* pointer) {
+ if (pointer->getType() != vector_pointer_type()) {
+ pointer = ir_builder()->CreateBitCast(pointer, vector_pointer_type());
+ }
+ ir_builder()->CreateAlignedStore(
+ value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_));
+}
+
+void VectorSupportLibrary::StoreScalar(llvm::Value* value,
+ llvm::Value* pointer) {
+ if (pointer->getType() != scalar_pointer_type()) {
+ pointer =
+ ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name());
+ }
+ ir_builder()->CreateAlignedStore(
+ value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_));
+}
+
+llvm::Value* VectorSupportLibrary::LoadBroadcast(llvm::Value* pointer) {
+ if (pointer->getType() != scalar_pointer_type()) {
+ pointer =
+ ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name());
+ }
+ return ir_builder()->CreateVectorSplat(
+ vector_size(), ir_builder()->CreateLoad(pointer), name());
+}
+
+llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) {
+ llvm::SmallVector<llvm::Constant*, 32> mask(vector_size(), nullptr);
+ for (unsigned i = vector_size(); i != 1; i >>= 1) {
+ // On every iteration, we shuffle half of the remaining lanes to the top
+ // half of shuffle, and add two old and the new vector.
+
+ for (unsigned j = 0; j < vector_size(); ++j) {
+ if (j < (i / 2)) {
+ mask[j] = ir_builder()->getInt32(i / 2 + j);
+ } else {
+ mask[j] = llvm::UndefValue::get(ir_builder()->getInt32Ty());
+ }
+ }
+
+ llvm::Value* half_remaining_lanes = ir_builder()->CreateShuffleVector(
+ vector, llvm::UndefValue::get(vector_type()),
+ llvm::ConstantVector::get(mask), "");
+ vector = Add(vector, half_remaining_lanes);
+ }
+
+ return ir_builder()->CreateExtractElement(vector, ir_builder()->getInt32(0),
+ name());
+}
+
+llvm::Value* VectorSupportLibrary::GetZeroVector() {
+ return llvm::Constant::getNullValue(vector_type());
+}
+
+llvm::Value* VectorSupportLibrary::GetZeroScalar() {
+ return llvm::Constant::getNullValue(scalar_type());
+}
+
+LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* ir_builder)
+ : ir_builder_(ir_builder) {
+ alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", ir_builder_);
+}
+
+llvm::Value* LlvmVariable::Get() { return ir_builder_->CreateLoad(alloca_); }
+
+void LlvmVariable::Set(llvm::Value* new_value) {
+ ir_builder_->CreateStore(new_value, alloca_);
+}
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h
new file mode 100644
index 0000000000..3072677ab0
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h
@@ -0,0 +1,174 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_VECTOR_SUPPORT_LIBRARY_H_
+#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_VECTOR_SUPPORT_LIBRARY_H_
+
+#include <string>
+
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+// A thin wrapper around llvm_util.h to make code generating vector math flow
+// more readable.
+class VectorSupportLibrary {
+ public:
+ // This VectorSupportLibrary instance remembers `primitive_type` and
+ // `vector_size`, and these are implicitly used by the methods on this
+ // instance (i.e. LoadVector will load a vector of type <`vector_size` x
+ // `primitive_type`>).
+ VectorSupportLibrary(PrimitiveType primitive_type, int64 vector_size,
+ llvm::IRBuilder<>* ir_builder, std::string name);
+
+ llvm::Value* Mul(llvm::Value* lhs, llvm::Value* rhs);
+ llvm::Value* Mul(int64 lhs, llvm::Value* rhs) {
+ return Mul(ir_builder()->getInt64(lhs), rhs);
+ }
+
+ llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs);
+ llvm::Value* Add(int64 lhs, llvm::Value* rhs) {
+ return Add(ir_builder()->getInt64(lhs), rhs);
+ }
+
+ llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) {
+ return Add(c, Mul(a, b));
+ }
+
+ llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
+ llvm::Value* offset_elements);
+ llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
+ int64 offset_elements) {
+ return ComputeOffsetPointer(base_pointer,
+ ir_builder()->getInt64(offset_elements));
+ }
+
+ llvm::Value* LoadVector(llvm::Value* pointer);
+
+ llvm::Value* LoadVector(llvm::Value* base_pointer,
+ llvm::Value* offset_elements) {
+ return LoadVector(ComputeOffsetPointer(base_pointer, offset_elements));
+ }
+
+ llvm::Value* LoadVector(llvm::Value* base_pointer, int64 offset_elements) {
+ return LoadVector(base_pointer, ir_builder()->getInt64(offset_elements));
+ }
+
+ llvm::Value* LoadScalar(llvm::Value* pointer);
+
+ llvm::Value* LoadScalar(llvm::Value* base_pointer,
+ llvm::Value* offset_elements) {
+ return LoadScalar(ComputeOffsetPointer(base_pointer, offset_elements));
+ }
+
+ llvm::Value* LoadScalar(llvm::Value* base_pointer, int64 offset_elements) {
+ return LoadScalar(base_pointer, ir_builder()->getInt64(offset_elements));
+ }
+
+ void StoreVector(llvm::Value* value, llvm::Value* pointer);
+
+ void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
+ llvm::Value* offset_elements) {
+ StoreVector(value, ComputeOffsetPointer(base_pointer, offset_elements));
+ }
+
+ void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
+ int64 offset_elements) {
+ StoreVector(value, base_pointer, ir_builder()->getInt64(offset_elements));
+ }
+
+ void StoreScalar(llvm::Value* value, llvm::Value* pointer);
+ void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
+ llvm::Value* offset_elements) {
+ StoreScalar(value, ComputeOffsetPointer(base_pointer, offset_elements));
+ }
+
+ void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
+ int64 offset_elements) {
+ StoreScalar(base_pointer, ir_builder()->getInt64(offset_elements));
+ }
+
+ llvm::Value* LoadBroadcast(llvm::Value* pointer);
+ llvm::Value* LoadBroadcast(llvm::Value* base_pointer,
+ llvm::Value* offset_elements) {
+ return LoadBroadcast(ComputeOffsetPointer(base_pointer, offset_elements));
+ }
+ llvm::Value* LoadBroadcast(llvm::Value* base_pointer, int64 offset_elements) {
+ return LoadBroadcast(base_pointer, ir_builder()->getInt64(offset_elements));
+ }
+
+ llvm::Value* AddReduce(llvm::Value* vector);
+
+ llvm::Value* GetZeroVector();
+ llvm::Value* GetZeroScalar();
+
+ llvm::IRBuilder<>* ir_builder() const { return ir_builder_; }
+ int64 vector_size() const { return vector_size_; }
+ llvm::Type* vector_type() const { return vector_type_; }
+ llvm::Type* vector_pointer_type() const { return vector_pointer_type_; }
+ llvm::Type* scalar_type() const { return scalar_type_; }
+ llvm::Type* scalar_pointer_type() const { return scalar_pointer_type_; }
+
+ const std::string& name() const { return name_; }
+
+ private:
+ int64 vector_size_;
+ PrimitiveType primitive_type_;
+ llvm::IRBuilder<>* ir_builder_;
+ llvm::Type* vector_type_;
+ llvm::Type* vector_pointer_type_;
+ llvm::Type* scalar_type_;
+ llvm::Type* scalar_pointer_type_;
+ std::string name_;
+};
+
+// This wraps an alloca-backed stack variable which LLVM's SSA construction pass
+// can later convert to a SSA value.
+class LlvmVariable {
+ public:
+ LlvmVariable(llvm::Type*, llvm::IRBuilder<>* ir_builder);
+
+ llvm::Value* Get();
+ void Set(llvm::Value* new_value);
+
+ private:
+ llvm::AllocaInst* alloca_;
+ llvm::IRBuilder<>* ir_builder_;
+};
+
+class VectorVariable : public LlvmVariable {
+ public:
+ VectorVariable(VectorSupportLibrary* vector_support,
+ llvm::Value* initial_value)
+ : LlvmVariable(vector_support->vector_type(),
+ vector_support->ir_builder()) {
+ Set(initial_value);
+ }
+};
+
+class ScalarVariable : public LlvmVariable {
+ public:
+ ScalarVariable(VectorSupportLibrary* vector_support,
+ llvm::Value* initial_value)
+ : LlvmVariable(vector_support->scalar_type(),
+ vector_support->ir_builder()) {
+ Set(initial_value);
+ }
+};
+} // namespace xla
+
+#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_VECTOR_SUPPORT_LIBRARY_H_
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index d4d35da9d6..06f43bd3cb 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -68,26 +68,6 @@ LocalService::LocalService(const ServiceOptions& options,
std::unique_ptr<Backend> execute_backend)
: Service(options, std::move(execute_backend)) {}
-namespace {
-// Returns the space required to allocate a shape. If
-// allocate_space_for_deep_copy the space includes all sub-buffers of
-// a tuple.
-int64 RequiredSpace(const Shape& shape, bool allocate_space_for_deep_copy,
- TransferManager* transfer_manager) {
- int64 size = 0;
- // TODO(b/33492279) remove once no devices represent result tuples as
- // contiguous buffers.
- if (allocate_space_for_deep_copy) {
- ShapeUtil::ForEachSubshape(
- shape, [&size, transfer_manager](const Shape& subshape,
- const ShapeIndex& /*index*/) {
- size += transfer_manager->GetByteSizeRequirement(subshape);
- });
- }
- return size;
-}
-} // namespace
-
StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
const ComputationHandle& computation,
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
index b92017c6cb..6aca6ba385 100644
--- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
+++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
@@ -23,6 +23,23 @@ limitations under the License.
namespace xla {
+namespace {
+
+// Gather fusion instructions from 'instruction' into 'fusion_instructions'.
+void GatherFusionInstructions(
+ HloInstruction* instruction,
+ std::vector<HloInstruction*>* fusion_instructions) {
+ CHECK_EQ(HloOpcode::kFusion, instruction->opcode());
+ for (auto* fused : instruction->fused_instructions()) {
+ if (fused->opcode() == HloOpcode::kFusion) {
+ GatherFusionInstructions(fused, fusion_instructions);
+ }
+ }
+ fusion_instructions->push_back(instruction);
+}
+
+} // namespace
+
/* static */ StatusOr<std::unique_ptr<LogicalBufferAnalysis>>
LogicalBufferAnalysis::Run(const HloModule* module) {
std::unique_ptr<LogicalBufferAnalysis> analysis(
@@ -41,15 +58,19 @@ Status LogicalBufferAnalysis::Analyze() {
// We filter out fusion computations, and get to them through fusion
// instructions. This is because it's possible to have orphaned (unreachable)
// fusion computations, and we don't want to try to assign buffers to those.
+ std::vector<HloInstruction*> fusion_instructions;
for (auto* computation : module_->MakeNonfusionComputations()) {
TF_RETURN_IF_ERROR(computation->Accept(this));
for (auto* instruction : computation->instructions()) {
if (instruction->opcode() != HloOpcode::kFusion) {
continue;
}
- TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this));
+ GatherFusionInstructions(instruction, &fusion_instructions);
}
}
+ for (auto* instruction : fusion_instructions) {
+ TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this));
+ }
return Status::OK();
}
@@ -104,6 +125,21 @@ Status LogicalBufferAnalysis::HandleBitcast(HloInstruction*) {
return Status::OK();
}
+Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction*) {
+ // RecvDone doesn't create a new buffer but rather aliases its input (Recv)
+ // tuple element at {0} to its output.
+ return Status::OK();
+}
+
+Status LogicalBufferAnalysis::HandleSend(HloInstruction* send) {
+ // Send creates new buffers for the top-level tuple and the context (tuple
+ // element at {1}). Tuple element at {0} is an alias of the Send operand, so
+ // we don't need to create a new Logical Buffer for that.
+ NewLogicalBuffer(send, /*index=*/{});
+ NewLogicalBuffer(send, /*index=*/{1});
+ return Status::OK();
+}
+
Status LogicalBufferAnalysis::HandleTuple(HloInstruction* tuple) {
// A Tuple instruction only creates the top-level buffer.
NewLogicalBuffer(tuple, /*index=*/{});
diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h
index a82e83ec5c..598d08b720 100644
--- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h
+++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h
@@ -60,6 +60,8 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault {
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleBitcast(HloInstruction* bitcast) override;
Status HandleCopy(HloInstruction* copy) override;
+ Status HandleRecvDone(HloInstruction* recv_done) override;
+ Status HandleSend(HloInstruction* send) override;
Status HandleSelect(HloInstruction* select) override;
// A map from the buffer ID to the logical buffer
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 6646be2e9a..47f4f0ade5 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -272,8 +272,6 @@ class Service : public ServiceInterface {
// Create a Hlo module config for the given program shape and arguments.
// execution_options is optional; if not given a default is used.
- // has_hybrid_result is used to initialize the same-named field in
- // HloModuleConfig -- see that class for documentation.
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 791d17365b..dcd726f22c 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -770,8 +771,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of binary operation"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of binary operation"));
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
+ lhs, tensorflow::strings::StrCat("lhs of binary operation ",
+ BinaryOperation_Name(operation))));
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
+ rhs, tensorflow::strings::StrCat("rhs of binary operation ",
+ BinaryOperation_Name(operation))));
switch (operation) {
case BINOP_DOT:
return InferDotOpShape(lhs, rhs);
@@ -1943,7 +1948,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
!std::is_permutation(dimensions.begin(), dimensions.end(),
indices.begin())) {
return InvalidArgument(
- "Reshape dimensions not a permutation of the operand dimensions.");
+ "Reshape dimensions [%s] are not a permutation of the operand "
+ "dimensions (operand shape is %s).",
+ tensorflow::str_util::Join(dimensions, ",").c_str(),
+ ShapeUtil::HumanString(operand).c_str());
}
return inferred_shape;
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc
index a2a442eb1a..a57ebf59e7 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.cc
+++ b/tensorflow/compiler/xla/service/shaped_buffer.cc
@@ -63,6 +63,14 @@ void ShapedBuffer::clear() {
}
}
+void ShapedBuffer::AddBufferAtIndex(
+ const perftools::gputools::DeviceMemoryBase& buffer,
+ const ShapeIndex& shape_index) {
+ *mutable_shape_index_to_buffer_entry()->mutable_element(shape_index) =
+ buffers().size();
+ mutable_buffers()->push_back(buffer);
+}
+
const se::DeviceMemoryBase& ShapedBuffer::buffer(
const ShapeIndex& index) const {
return buffers_[shape_index_to_buffer_entry_.element(index)];
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h
index e5ea06fb13..b440948700 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.h
+++ b/tensorflow/compiler/xla/service/shaped_buffer.h
@@ -75,6 +75,10 @@ class ShapedBuffer {
// Set all device memory pointers in the object to null.
void clear();
+ // Adds a new buffer at the given shape index.
+ void AddBufferAtIndex(const perftools::gputools::DeviceMemoryBase& buffer,
+ const ShapeIndex& shape_index);
+
protected:
// The shape of the device buffer with layout.
const Shape shape_;
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
index df537bd7c1..0c84856647 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -120,6 +120,23 @@ void PointsToSet::add_tuple_source(const ShapeIndex& index,
tree_.mutable_element(index)->tuple_sources.insert(tuple);
}
+namespace {
+
+// Gather fusion instructions from 'instruction' into 'fusion_instructions'.
+void GatherFusionInstructions(
+ HloInstruction* instruction,
+ std::vector<HloInstruction*>* fusion_instructions) {
+ CHECK_EQ(HloOpcode::kFusion, instruction->opcode());
+ for (auto* fused : instruction->fused_instructions()) {
+ if (fused->opcode() == HloOpcode::kFusion) {
+ GatherFusionInstructions(fused, fusion_instructions);
+ }
+ }
+ fusion_instructions->push_back(instruction);
+}
+
+} // namespace
+
/* static */ StatusOr<std::unique_ptr<TuplePointsToAnalysis>>
TuplePointsToAnalysis::Run(const HloModule* module) {
auto logical_buffer_analysis = LogicalBufferAnalysis::Run(module);
@@ -137,20 +154,23 @@ Status TuplePointsToAnalysis::Analyze() {
logical_buffer_aliases_.resize(
logical_buffer_analysis_->num_logical_buffers());
+ std::vector<HloInstruction*> fusion_instructions;
for (auto* computation : module_->MakeNonfusionComputations()) {
TF_RETURN_IF_ERROR(computation->Accept(this));
TF_RETURN_IF_ERROR(
PopulateDefinedBuffersAndAliases(computation->instructions()));
- // Run points-to analysis on fusion instructions in 'computation'.
for (auto* instruction : computation->instructions()) {
- if (instruction->opcode() != HloOpcode::kFusion) {
- continue;
+ if (instruction->opcode() == HloOpcode::kFusion) {
+ GatherFusionInstructions(instruction, &fusion_instructions);
}
- TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this));
- TF_RETURN_IF_ERROR(
- PopulateDefinedBuffersAndAliases(instruction->fused_instructions()));
}
}
+ // Run points-to analysis on fusion instructions in 'computation'.
+ for (auto* instruction : fusion_instructions) {
+ TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this));
+ TF_RETURN_IF_ERROR(
+ PopulateDefinedBuffersAndAliases(instruction->fused_instructions()));
+ }
XLA_VLOG_LINES(3, ToString());
@@ -253,6 +273,64 @@ Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) {
return Status::OK();
}
+Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
+ // RecvDone aliases its input (Recv) tuple element {0} to its output.
+ PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done);
+ const PointsToSet& operand_points_to_set =
+ GetPointsToSet(recv_done->operand(0));
+
+ // Recursively copy the points to set of the operand tuple {0}.
+ points_to_set.ForEachMutableElement(
+ [this, &points_to_set, &operand_points_to_set](
+ const ShapeIndex& index, PointsToSet::BufferList* buffers) {
+ ShapeIndex src_index({0});
+ for (auto element : index) {
+ src_index.push_back(element);
+ }
+ *buffers = operand_points_to_set.element(src_index);
+ for (auto& tuple_source :
+ operand_points_to_set.tuple_sources(src_index)) {
+ points_to_set.add_tuple_source(index, tuple_source);
+ }
+ });
+ return Status::OK();
+}
+
+Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) {
+ // Send creates a tuple of {aliased operand, U32 context}.
+ PointsToSet& points_to_set = CreateEmptyPointsToSet(send);
+
+ // Creates the points to set for the tuple and its element at {1}.
+ auto top_buffer = points_to_set.mutable_element(ShapeIndex({}));
+ top_buffer->push_back(
+ &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({})));
+ points_to_set.add_tuple_source({}, send);
+
+ auto context_buffer = points_to_set.mutable_element(ShapeIndex({1}));
+ context_buffer->push_back(
+ &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({1})));
+
+ // Recursively copy the points to set of the operand to output tuple {0}.
+ const PointsToSet& operand_points_to_set = GetPointsToSet(send->operand(0));
+ operand_points_to_set.ForEachElement(
+ [&points_to_set, &operand_points_to_set](
+ const ShapeIndex& src_index,
+ const PointsToSet::BufferList& points_to) {
+ ShapeIndex target_index({0});
+ for (auto element : src_index) {
+ target_index.push_back(element);
+ }
+ *points_to_set.mutable_element(target_index) = points_to;
+
+ for (HloInstruction* tuple :
+ operand_points_to_set.tuple_sources(src_index)) {
+ points_to_set.add_tuple_source(target_index, tuple);
+ }
+ });
+
+ return Status::OK();
+}
+
Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) {
tensorflow::gtl::ArraySlice<HloInstruction*> operands(tuple->operands());
PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple);
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
index e6157a1ed1..8928de107e 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
@@ -251,6 +251,8 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleBitcast(HloInstruction* bitcast) override;
Status HandleCopy(HloInstruction* copy) override;
+ Status HandleRecvDone(HloInstruction* recv_done) override;
+ Status HandleSend(HloInstruction* send) override;
Status HandleSelect(HloInstruction* select) override;
string ToString() const;
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index 694ed57fa2..dec446d4da 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -313,6 +313,51 @@ TEST_F(TuplePointsToAnalysisTest, TupleCopy) {
{constant1, constant2, copy});
}
+TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) {
+ // Send forwards its operand to the output tuple at {0}.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ auto send = builder.AddInstruction(
+ HloInstruction::CreateSend(constant, /*channel_id=*/0));
+ auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ EXPECT_FALSE(points_to_analysis_->GetPointsToSet(send).IsAmbiguous());
+ EXPECT_TRUE(points_to_analysis_->GetPointsToSet(send).IsDistinct());
+ EXPECT_FALSE(points_to_analysis_->GetPointsToSet(send_done).IsAmbiguous());
+ EXPECT_TRUE(points_to_analysis_->GetPointsToSet(send_done).IsDistinct());
+
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(send).element({}), {send});
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(send).element({0}), {constant});
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(send_done).CreateFlattenedSet(),
+ {send_done});
+ ExpectHasBufferAliases(constant, {}, {{constant, {}}, {send, {0}}});
+}
+
+TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) {
+ // RecvDone forwards its operand tuple element at {0} to the output.
+ auto builder = HloComputation::Builder(TestName());
+ auto recv = builder.AddInstruction(HloInstruction::CreateRecv(
+ ShapeUtil::MakeShape(F32, {1, 2, 3}), /*channel_id=*/0));
+ auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ EXPECT_FALSE(points_to_analysis_->GetPointsToSet(recv).IsAmbiguous());
+ EXPECT_TRUE(points_to_analysis_->GetPointsToSet(recv).IsDistinct());
+ EXPECT_FALSE(points_to_analysis_->GetPointsToSet(recv_done).IsAmbiguous());
+ EXPECT_TRUE(points_to_analysis_->GetPointsToSet(recv_done).IsDistinct());
+
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(recv).element({}), {recv});
+ ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {}}});
+}
+
TEST_F(TuplePointsToAnalysisTest, TupleSelect) {
// Select from two different tuples. This should create an ambiguous points to
// set containing the union of both sides.
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index e9d182509b..8d5bb08e51 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -2927,8 +2927,9 @@ void ComputationLowerer::Visit(
case OpRequest::kRecvRequest: {
const RecvRequest& recv_request = request.request().recv_request();
- hlo_instruction = add_instruction(HloInstruction::CreateRecv(
+ HloInstruction* recv = add_instruction(HloInstruction::CreateRecv(
request.output_shape(), recv_request.channel_handle().handle()));
+ hlo_instruction = add_instruction(HloInstruction::CreateRecvDone(recv));
break;
}
@@ -3120,8 +3121,9 @@ void ComputationLowerer::Visit(
case OpRequest::kSendRequest: {
const SendRequest& send_request = request.request().send_request();
HloInstruction* operand = lookup_instruction(send_request.operand());
- hlo_instruction = add_instruction(HloInstruction::CreateSend(
+ HloInstruction* send = add_instruction(HloInstruction::CreateSend(
operand, send_request.channel_handle().handle()));
+ hlo_instruction = add_instruction(HloInstruction::CreateSendDone(send));
break;
}
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
index 65734f91bc..2fac914892 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
@@ -58,7 +58,9 @@ static bool ContainsSendOrRecv(const HloComputation* comp) {
static bool IsOrContainsSendOrRecv(const HloInstruction* instr) {
if (instr->opcode() == HloOpcode::kSend ||
- instr->opcode() == HloOpcode::kRecv) {
+ instr->opcode() == HloOpcode::kSendDone ||
+ instr->opcode() == HloOpcode::kRecv ||
+ instr->opcode() == HloOpcode::kRecvDone) {
return true;
}
for (const auto& subcomp : instr->called_computations()) {
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
index 8e1a2dcde1..d99b31dc00 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
@@ -144,10 +144,11 @@ TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsSend) {
auto* while_op = computation->root_instruction();
ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
auto* while_body = while_op->while_body();
- while_body->AddInstruction(HloInstruction::CreateSend(
+ auto* send = while_body->AddInstruction(HloInstruction::CreateSend(
while_body->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))),
/*channel_id=*/0));
+ while_body->AddInstruction(HloInstruction::CreateSendDone(send));
EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
}
@@ -156,9 +157,10 @@ TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsRecv) {
auto* while_op = computation->root_instruction();
ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
auto* while_body = while_op->while_body();
- while_body->AddInstruction(
+ auto* recv = while_body->AddInstruction(
HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}),
/*channel_id=*/0));
+ while_body->AddInstruction(HloInstruction::CreateRecvDone(recv));
EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
}
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h
index 64a36471b9..bf8d190150 100644
--- a/tensorflow/compiler/xla/shape_tree.h
+++ b/tensorflow/compiler/xla/shape_tree.h
@@ -116,6 +116,7 @@ class ShapeTree {
ShapeTree(const Shape* shape, const T& init_value);
ShapeTree(const ShapeTree& other) { *this = other; }
+ ShapeTree(ShapeTree&&) = default;
ShapeTree& operator=(const ShapeTree& other) {
root_ = other.root_;
@@ -132,6 +133,8 @@ class ShapeTree {
return *this;
}
+ ShapeTree& operator=(ShapeTree&& other) = default;
+
// Returns the data element associated with the array in the shape at the
// given index (see ShapeUtil::GetSubshape for how indexes are defined).
const T& element(const ShapeIndex& index) const;
@@ -152,28 +155,57 @@ class ShapeTree {
using const_iterator = ShapeTreeIterator<T, /*is_const=*/true>;
// begin/end for iterating over all nodes.
- iterator begin() { return iterator(&root_, /*iterate_leaves_only=*/false); }
- iterator end() { return iterator(nullptr, /*iterate_leaves_only=*/false); }
+ iterator begin() {
+ return iterator(&root_, /*iterate_leaves_only=*/false,
+ /*reverse=*/false);
+ }
+ iterator end() {
+ return iterator(nullptr, /*iterate_leaves_only=*/false,
+ /*reverse=*/false);
+ }
const_iterator begin() const {
- return const_iterator(&root_, /*iterate_leaves_only=*/false);
+ return const_iterator(&root_, /*iterate_leaves_only=*/false,
+ /*reverse=*/false);
}
const_iterator end() const {
- return const_iterator(nullptr, /*iterate_leaves_only=*/false);
+ return const_iterator(nullptr, /*iterate_leaves_only=*/false,
+ /*reverse=*/false);
+ }
+
+ // rbegin/rend for iterating over all nodes in reverse.
+ iterator rbegin() {
+ return iterator(&root_, /*iterate_leaves_only=*/false,
+ /*reverse=*/true);
+ }
+ iterator rend() {
+ return iterator(nullptr, /*iterate_leaves_only=*/false,
+ /*reverse=*/true);
+ }
+ const_iterator rbegin() const {
+ return const_iterator(&root_, /*iterate_leaves_only=*/false,
+ /*reverse=*/true);
+ }
+ const_iterator rend() const {
+ return const_iterator(nullptr, /*iterate_leaves_only=*/false,
+ /*reverse=*/true);
}
// leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no
// children).
iterator leaf_begin() {
- return iterator(&root_, /*iterate_leaves_only=*/true);
+ return iterator(&root_, /*iterate_leaves_only=*/true, /*reverse=*/false);
}
iterator leaf_end() {
- return iterator(nullptr, /*iterate_leaves_only=*/true);
+ return iterator(nullptr, /*iterate_leaves_only=*/true,
+ /*reverse=*/false);
}
const_iterator leaf_begin() const {
- return const_iterator(&root_, /*iterate_leaves_only=*/true);
+ return const_iterator(&root_, /*iterate_leaves_only=*/true,
+ /*reverse=*/false);
}
const_iterator leaf_end() const {
- return const_iterator(nullptr, /*iterate_leaves_only=*/true);
+ return const_iterator(nullptr, /*iterate_leaves_only=*/true,
+ /*reverse=*/false);
}
// range-based iterator for leaf_begin()/leaf_end().
tensorflow::gtl::iterator_range<iterator> leaves() {
@@ -183,6 +215,22 @@ class ShapeTree {
return tensorflow::gtl::make_range(leaf_begin(), leaf_end());
}
+ iterator leaf_rbegin() {
+ return iterator(&root_, /*iterate_leaves_only=*/true, /*reverse=*/true);
+ }
+ iterator leaf_rend() {
+ return iterator(nullptr, /*iterate_leaves_only=*/true,
+ /*reverse=*/true);
+ }
+ const_iterator leaf_rbegin() const {
+ return const_iterator(&root_, /*iterate_leaves_only=*/true,
+ /*reverse=*/true);
+ }
+ const_iterator leaf_rend() const {
+ return const_iterator(nullptr, /*iterate_leaves_only=*/true,
+ /*reverse=*/true);
+ }
+
// Recursively traverses the shape and calls the given function at each
// element. The function has the following arguments:
//
@@ -277,42 +325,61 @@ class ShapeTreeIterator : public std::iterator<std::forward_iterator_tag,
// Construct an iterator pointing at node. Node must either be the tree root
// or nullptr (which is equivalent to end() and should not be dereferenced or
// incremented). If iterate_leaves_only is true, the iterator will not include
- // interior tree nodes, only leaves.
- ShapeTreeIterator(NodeType* node, bool iterate_leaves_only)
- : node_(node), iterate_leaves_only_(iterate_leaves_only) {
- if (node_ && !node_->children.empty() && iterate_leaves_only) {
- ++*this;
+ // interior tree nodes, only leaves. If reverse is true, the iterator will
+ // visit nodes in the reverse of pre-order traversal.
+ ShapeTreeIterator(NodeType* node, bool iterate_leaves_only, bool reverse)
+ : node_(node),
+ iterate_leaves_only_(iterate_leaves_only),
+ reverse_(reverse) {
+ if (node_) {
+ if (reverse_) {
+ while (!node_->children.empty()) {
+ const int child_index = node_->children.size() - 1;
+ stack_.push_back({node_, child_index});
+ node_ = node_->children[child_index].get();
+ }
+ } else {
+ if (!node_->children.empty() && iterate_leaves_only) {
+ ++*this;
+ }
+ }
}
}
ShapeTreeIterator(const ShapeTreeIterator& other)
: node_(other.node_),
stack_(other.stack_),
- iterate_leaves_only_(other.iterate_leaves_only_) {}
+ iterate_leaves_only_(other.iterate_leaves_only_),
+ reverse_(other.reverse_) {}
ShapeTreeIterator& operator++() {
CHECK_NE(nullptr, node_) << "walking off the end() of an iterator!";
- // We're doing a pre-order walk, so if our current node has children take
- // the first child.
- if (!node_->children.empty()) {
- stack_.push_back({node_, /*child-index=*/0});
- node_ = node_->children[0].get();
- if (node_->children.empty() || !iterate_leaves_only_) {
- return *this;
- } else {
- // This is a non-leaf; tail-recurse.
- return ++(*this);
+ if (reverse_) {
+ while (!stack_.empty()) {
+ node_ = stack_.back().first;
+ int64 next_child_index = stack_.back().second - 1;
+ stack_.pop_back();
+ if (next_child_index < 0) {
+ if (!iterate_leaves_only_) {
+ // All children are visited, yield <node_>.
+ return *this;
+ }
+ } else {
+ stack_.push_back({node_, next_child_index});
+ node_ = node_->children[next_child_index].get();
+ while (!node_->children.empty()) {
+ const int child_index = node_->children.size() - 1;
+ stack_.push_back({node_, child_index});
+ node_ = node_->children[child_index].get();
+ }
+ return *this;
+ }
}
- }
- // Otherwise we are currently at a leaf. Walk back up until a node contains
- // a child we haven't visited yet.
- while (!stack_.empty()) {
- node_ = stack_.back().first;
- int64 next_child_index = stack_.back().second + 1;
- stack_.pop_back();
- if (node_->children.size() > next_child_index) {
- stack_.push_back({node_, next_child_index});
- node_ = node_->children[next_child_index].get();
-
+ } else {
+ // We're doing a pre-order walk, so if our current node has children take
+ // the first child.
+ if (!node_->children.empty()) {
+ stack_.push_back({node_, /*child-index=*/0});
+ node_ = node_->children[0].get();
if (node_->children.empty() || !iterate_leaves_only_) {
return *this;
} else {
@@ -320,6 +387,24 @@ class ShapeTreeIterator : public std::iterator<std::forward_iterator_tag,
return ++(*this);
}
}
+ // Otherwise we are currently at a leaf. Walk back up until a node
+ // contains a child we haven't visited yet.
+ while (!stack_.empty()) {
+ node_ = stack_.back().first;
+ int64 next_child_index = stack_.back().second + 1;
+ stack_.pop_back();
+ if (node_->children.size() > next_child_index) {
+ stack_.push_back({node_, next_child_index});
+ node_ = node_->children[next_child_index].get();
+
+ if (node_->children.empty() || !iterate_leaves_only_) {
+ return *this;
+ } else {
+ // This is a non-leaf; tail-recurse.
+ return ++(*this);
+ }
+ }
+ }
}
// We've walked off the end of the tree. Set node_ to nullptr to signify
// end().
@@ -361,6 +446,8 @@ class ShapeTreeIterator : public std::iterator<std::forward_iterator_tag,
std::vector<std::pair<NodeType*, int64>> stack_;
// True if we should not include interior nodes in our walk.
bool iterate_leaves_only_;
+ // True if we should yield the reverse of the pre-order traversal.
+ bool reverse_;
// Placeholder for the current value. Ideally this wouldn't exist and would
// just be an rvalue, but operator -> needs to return a pointer to something.
// We cannot just use a plain old value_type as it contains a reference so
diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc
index 7b4b5cb0fb..4b6ab77281 100644
--- a/tensorflow/compiler/xla/shape_tree_test.cc
+++ b/tensorflow/compiler/xla/shape_tree_test.cc
@@ -456,6 +456,26 @@ TEST_F(ShapeTreeTest, IterateOrder) {
{2, 1}}));
}
+TEST_F(ShapeTreeTest, ReverseIterateOrder) {
+ ShapeTree<int> t(nested_tuple_shape_, 42);
+ std::vector<ShapeIndex> v;
+ for (auto it = t.rbegin(); it != t.rend(); ++it) {
+ v.push_back(it->first);
+ }
+ EXPECT_EQ(v, (std::vector<ShapeIndex>{
+ {2, 1},
+ {2, 0, 1},
+ {2, 0, 0},
+ {2, 0},
+ {2},
+ {1, 1},
+ {1, 0},
+ {1},
+ {0},
+ {},
+ }));
+}
+
TEST_F(ShapeTreeTest, IterateOrderLeaves) {
ShapeTree<int> t(nested_tuple_shape_, 42);
std::vector<ShapeIndex> v;
@@ -466,5 +486,21 @@ TEST_F(ShapeTreeTest, IterateOrderLeaves) {
{0}, {1, 0}, {1, 1}, {2, 0, 0}, {2, 0, 1}, {2, 1}}));
}
+TEST_F(ShapeTreeTest, ReverseIterateOrderLeaves) {
+ ShapeTree<int> t(nested_tuple_shape_, 42);
+ std::vector<ShapeIndex> v;
+ for (auto it = t.leaf_rbegin(); it != t.leaf_rend(); ++it) {
+ v.push_back(it->first);
+ }
+ EXPECT_EQ(v, (std::vector<ShapeIndex>{
+ {2, 1},
+ {2, 0, 1},
+ {2, 0, 0},
+ {1, 1},
+ {1, 0},
+ {0},
+ }));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index b5eb81dfc6..4d0bafa908 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -263,6 +263,7 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
case S32:
case S64:
case F16:
+ case BF16:
case F32:
case F64:
return true;
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 8f8d4a73c9..82a513a65a 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -68,6 +68,9 @@ class ShapeIndex {
const int64* data() const { return indices_.data(); }
+ int64 back() const { return indices_.back(); }
+ int64& back() { return indices_.back(); }
+
const int64& operator[](size_t i) const { return indices_[i]; }
int64& operator[](size_t i) { return indices_[i]; }
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 4e1be24b61..3e62481629 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -61,13 +61,14 @@ generate_backend_test_macros()
cc_library(
name = "test_utils",
- testonly = True,
+ srcs = ["test_utils.cc"],
hdrs = ["test_utils.h"],
deps = [
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
],
)
@@ -1343,22 +1344,23 @@ xla_test(
],
)
-xla_test(
+tf_cc_test(
name = "llvm_compiler_test",
srcs = ["llvm_compiler_test.cc"],
- backends = [
- "cpu",
- "gpu",
- "cpu_parallel",
- ],
+ tags = ["requires-gpu-sm35"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla/service:backend",
+ "//tensorflow/compiler/xla/service:cpu_plugin",
+ "//tensorflow/compiler/xla/service:gpu_plugin",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:llvm_compiler",
- "//tensorflow/compiler/xla/tests:hlo_test_base",
- "//tensorflow/compiler/xla/tests:literal_test_util",
- "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/compiler/xla/service:platform_util",
+ "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
+ "//tensorflow/compiler/xla/service/gpu:gpu_compiler",
"//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/stream_executor",
"@llvm//:core",
],
)
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index 065bce7e31..ef54714e46 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -346,6 +346,60 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
LiteralTestUtil::ExpectNearTuple(expected, *actual, error);
}
+void ClientLibraryTestBase::ComputeAndCompare(
+ ComputationBuilder* builder, const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<Literal> arguments) {
+ auto status_or_data = ComputeValueAndReference(builder, operand, arguments);
+ EXPECT_IS_OK(status_or_data);
+ if (!status_or_data.ok()) {
+ return;
+ }
+ std::unique_ptr<Literal> reference, result;
+ std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
+ LiteralTestUtil::ExpectEqual(*reference, *result);
+}
+
+void ClientLibraryTestBase::ComputeAndCompare(
+ ComputationBuilder* builder, const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<Literal> arguments, ErrorSpec error) {
+ auto status_or_data = ComputeValueAndReference(builder, operand, arguments);
+ EXPECT_IS_OK(status_or_data);
+ if (!status_or_data.ok()) {
+ return;
+ }
+ std::unique_ptr<Literal> reference, result;
+ std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
+ LiteralTestUtil::ExpectNear(*reference, *result, error);
+}
+
+StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
+ClientLibraryTestBase::ComputeValueAndReference(
+ ComputationBuilder* builder, const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<Literal> arguments) {
+ // Transfer the arguments to the executor service. We put the unique_ptr's
+ // into a vector to keep the data alive on the service until the end of this
+ // function.
+ std::vector<std::unique_ptr<GlobalData>> argument_data;
+ for (const auto& arg : arguments) {
+ TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg));
+ argument_data.push_back(std::move(data));
+ }
+
+ // Create raw pointers to the GlobalData for the rest of the call stack.
+ std::vector<GlobalData*> argument_data_ptr;
+ std::transform(
+ argument_data.begin(), argument_data.end(),
+ std::back_inserter(argument_data_ptr),
+ [](const std::unique_ptr<GlobalData>& data) { return data.get(); });
+
+ TF_ASSIGN_OR_RETURN(
+ auto reference,
+ builder->ComputeConstant(operand, /*output_layout=*/nullptr, arguments));
+ TF_ASSIGN_OR_RETURN(auto result,
+ ExecuteAndTransfer(builder, argument_data_ptr));
+ return std::make_pair(std::move(reference), std::move(result));
+}
+
Computation ClientLibraryTestBase::CreateScalarRelu() {
ComputationBuilder builder(client_, "relu");
auto z_value = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "z_value");
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 7cfc276ec1..b578667735 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -196,6 +196,16 @@ class ClientLibraryTestBase : public ::testing::Test {
ComputationBuilder* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec abs_error);
+ // Convenience method for running a built computation and comparing the result
+ // with the HloEvaluator.
+ void ComputeAndCompare(ComputationBuilder* builder,
+ const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<Literal> arguments);
+ void ComputeAndCompare(ComputationBuilder* builder,
+ const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<Literal> arguments,
+ ErrorSpec error);
+
// Create scalar operations for use in reductions.
Computation CreateScalarRelu();
Computation CreateScalarMax();
@@ -298,6 +308,13 @@ class ClientLibraryTestBase : public ::testing::Test {
const std::function<void(const Literal& actual,
const string& error_message)>& verify_output,
const Shape* output_with_layout = nullptr);
+
+ // Executes the computation and calculates the expected reference value using
+ // the HloEvaluator. Returns two literal in the order of (expected, actual).
+ StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
+ ComputeValueAndReference(ComputationBuilder* builder,
+ const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<Literal> arguments);
};
template <typename NativeT>
@@ -469,8 +486,7 @@ template <typename NativeT>
std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1(
const int width, NativeT min_value, NativeT max_value, uint32 seed) {
std::vector<NativeT> result(width);
- test_utils::PseudorandomGenerator<NativeT> generator(min_value, max_value,
- seed);
+ PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
for (int i = 0; i < width; ++i) {
result[i] = generator.get();
}
@@ -482,8 +498,7 @@ std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2(
const int rows, const int cols, NativeT min_value, NativeT max_value,
uint32 seed) {
auto result = MakeUnique<Array2D<NativeT>>(rows, cols);
- test_utils::PseudorandomGenerator<NativeT> generator(min_value, max_value,
- seed);
+ PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
for (int y = 0; y < rows; ++y) {
for (int x = 0; x < cols; ++x) {
(*result)(y, x) = generator.get();
diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc
index 0853feeebd..183bcf1dd3 100644
--- a/tensorflow/compiler/xla/tests/client_test.cc
+++ b/tensorflow/compiler/xla/tests/client_test.cc
@@ -54,8 +54,8 @@ TEST_F(ClientTest, ExecuteWithLayout) {
.ConsumeValueOrDie();
std::unique_ptr<Literal> expected_literal =
- test_utils::CreateR2LiteralWithLayout<int32>({{11, 22}, {33, 44}},
- transfer_layout);
+ Literal::CreateR2WithLayout<int32>(
+ {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
auto computed = client_->Transfer(*data, &expected_literal->shape());
diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
index 707e439245..0f780fa87e 100644
--- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc
+++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
@@ -138,13 +138,13 @@ XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) {
// layouts. Use these arrays as parameters to a simple computation. If the
// layout of the array changes then computation should be recompiled (cache
// miss).
- auto rowmaj_array = test_utils::CreateR2LiteralWithLayout(
- {{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{1, 0});
+ auto rowmaj_array = Literal::CreateR2WithLayout(
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0}));
auto rowmaj_handle =
client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie();
- auto colmaj_array = test_utils::CreateR2LiteralWithLayout(
- {{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{0, 1});
+ auto colmaj_array = Literal::CreateR2WithLayout(
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}));
auto colmaj_handle =
client_->TransferToServer(*colmaj_array).ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc
index d423c78476..5226a78386 100644
--- a/tensorflow/compiler/xla/tests/compute_constant_test.cc
+++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc
@@ -264,8 +264,8 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
ASSERT_TRUE(computed.ok()) << computed.status();
std::unique_ptr<Literal> expected_literal =
- test_utils::CreateR2LiteralWithLayout<int32>({{11, 22}, {33, 44}},
- layout);
+ Literal::CreateR2WithLayout<int32>({{11, 22}, {33, 44}},
+ LayoutUtil::MakeLayout(layout));
LiteralTestUtil::AssertEqualShapesAndLayouts(
expected_literal->shape(), computed.ValueOrDie()->shape());
LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index 0cc2e5fb7e..7425f778a6 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -82,177 +82,127 @@ XLA_TEST_F(ConvolutionTest, ForwardPassConvolution_3x3x256_256_OutputZ_Iota) {
ComputationBuilder builder(client_, TestName());
auto lhs = builder.ConstantR4FromArray4D<float>(*alhs);
auto rhs = builder.ConstantR4FromArray4D<float>(*arhs);
- builder.Conv(lhs, rhs, {1, 1}, Padding::kValid);
+ auto conv = builder.Conv(lhs, rhs, {1, 1}, Padding::kValid);
- std::unique_ptr<Array4D<float>> aexpected =
- ReferenceUtil::ConvArray4D(*alhs, *arhs, {1, 1}, Padding::kValid);
-
- ComputeAndCompareR4<float>(&builder, *aexpected, {}, error_spec_);
+ ComputeAndCompare(&builder, conv, {}, error_spec_);
}
TEST_F(ConvolutionTest, Convolve_1x1x1x2_1x1x1x2_Valid) {
ComputationBuilder builder(client_, TestName());
- {
- Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
- Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
- }
-
- Array4D<float> input(1, 1, 1, 2);
- input.FillWithYX(Array2D<float>({
+ Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
+ Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+ auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> input_data(1, 1, 1, 2);
+ input_data.FillWithYX(Array2D<float>({
{1, 2},
}));
- Array4D<float> filter(1, 1, 1, 2);
- filter.FillWithYX(Array2D<float>({
+ Array4D<float> filter_data(1, 1, 1, 2);
+ filter_data.FillWithYX(Array2D<float>({
{5, 6},
}));
- std::unique_ptr<Array4D<float>> aexpected =
- ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid);
-
- auto input_literal =
- client_->TransferToServer(*Literal::CreateR4FromArray4D(input))
- .ConsumeValueOrDie();
- auto filter_literal =
- client_->TransferToServer(*Literal::CreateR4FromArray4D(filter))
- .ConsumeValueOrDie();
-
- ComputeAndCompareR4<float>(&builder, *aexpected,
- {input_literal.get(), filter_literal.get()},
- error_spec_);
+ ComputeAndCompare(&builder, conv,
+ {*Literal::CreateFromArray(input_data),
+ *Literal::CreateFromArray(filter_data)},
+ error_spec_);
}
// Tests valid padding for 2D convolution in raster space.
TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Valid) {
ComputationBuilder builder(client_, TestName());
- {
- Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
- Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
- }
+ Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
+ Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+ auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
- Array4D<float> input(1, 1, 4, 4);
+ Array4D<float> input_data(1, 1, 4, 4);
// clang-format off
- input.FillWithYX(Array2D<float>({
+ input_data.FillWithYX(Array2D<float>({
{1, 2, 3, 4 },
{5, 6, 7, 8 },
{9, 10, 11, 12},
{13, 14, 15, 16},
}));
// clang-format on
- Array4D<float> filter(1, 1, 2, 2);
+ Array4D<float> filter_data(1, 1, 2, 2);
// clang-format off
- filter.FillWithYX(Array2D<float>({
+ filter_data.FillWithYX(Array2D<float>({
{5, 6},
{7, 8},
}));
// clang-format on
-
- std::unique_ptr<Array4D<float>> aexpected =
- ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid);
-
- auto input_literal =
- client_->TransferToServer(*Literal::CreateR4FromArray4D(input))
- .ConsumeValueOrDie();
- auto filter_literal =
- client_->TransferToServer(*Literal::CreateR4FromArray4D(filter))
- .ConsumeValueOrDie();
-
- ComputeAndCompareR4<float>(&builder, *aexpected,
- {input_literal.get(), filter_literal.get()},
- error_spec_);
+ ComputeAndCompare(&builder, conv,
+ {*Literal::CreateFromArray(input_data),
+ *Literal::CreateFromArray(filter_data)},
+ error_spec_);
}
// Tests same padding for 2D convolution in raster space.
TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Same) {
ComputationBuilder builder(client_, TestName());
- {
- Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
- Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
- builder.Conv(input, filter, {1, 1}, Padding::kSame);
- }
+ Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
+ Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+ auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame);
- Array4D<float> input(1, 1, 4, 4);
+ Array4D<float> input_data(1, 1, 4, 4);
// clang-format off
- input.FillWithYX(Array2D<float>({
+ input_data.FillWithYX(Array2D<float>({
{1, 2, 3, 4 },
{5, 6, 7, 8 },
{9, 10, 11, 12},
{13, 14, 15, 16},
}));
// clang-format on
- Array4D<float> filter(1, 1, 2, 2);
+ Array4D<float> filter_data(1, 1, 2, 2);
// clang-format off
- filter.FillWithYX(Array2D<float>({
+ filter_data.FillWithYX(Array2D<float>({
{5, 6},
{7, 8},
}));
// clang-format on
-
- std::unique_ptr<Array4D<float>> aexpected =
- ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame);
-
- auto input_literal =
- client_->TransferToServer(*Literal::CreateR4FromArray4D(input))
- .ConsumeValueOrDie();
- auto filter_literal =
- client_->TransferToServer(*Literal::CreateR4FromArray4D(filter))
- .ConsumeValueOrDie();
-
- ComputeAndCompareR4<float>(&builder, *aexpected,
- {input_literal.get(), filter_literal.get()},
- error_spec_);
+ ComputeAndCompare(&builder, conv,
+ {*Literal::CreateFromArray(input_data),
+ *Literal::CreateFromArray(filter_data)},
+ error_spec_);
}
// Tests same padding for 2D convolution in raster space with an odd sized
// kernel.
TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x3x3_Same) {
ComputationBuilder builder(client_, TestName());
- {
- Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
- Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
- builder.Conv(input, filter, {1, 1}, Padding::kSame);
- }
+ Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
+ Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+ auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame);
- Array4D<float> input(1, 1, 4, 4);
+ Array4D<float> input_data(1, 1, 4, 4);
// clang-format off
- input.FillWithYX(Array2D<float>({
+ input_data.FillWithYX(Array2D<float>({
{1, 2, 3, 4 },
{5, 6, 7, 8 },
{9, 10, 11, 12},
{13, 14, 15, 16},
}));
// clang-format on
- Array4D<float> filter(1, 1, 3, 3);
+ Array4D<float> filter_data(1, 1, 3, 3);
// clang-format off
- filter.FillWithYX(Array2D<float>({
+ filter_data.FillWithYX(Array2D<float>({
{ 5, 6, 7},
{ 8, 9, 10},
{11, 12, 13},
}));
// clang-format on
-
- std::unique_ptr<Array4D<float>> aexpected =
- ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame);
-
- auto input_literal =
- client_->TransferToServer(*Literal::CreateR4FromArray4D(input))
- .ConsumeValueOrDie();
- auto filter_literal =
- client_->TransferToServer(*Literal::CreateR4FromArray4D(filter))
- .ConsumeValueOrDie();
-
- ComputeAndCompareR4<float>(&builder, *aexpected,
- {input_literal.get(), filter_literal.get()},
- error_spec_);
+ ComputeAndCompare(&builder, conv,
+ {*Literal::CreateFromArray(input_data),
+ *Literal::CreateFromArray(filter_data)},
+ error_spec_);
}
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index cf089d748d..b72dd2707c 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -177,15 +177,15 @@ void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major,
bool rhs_row_major) {
auto lhs_handle =
client_
- ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
+ ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
{{1.0, 2.0}, {3.0, -4.0}},
- MinorToMajorForIsRowMajor(lhs_row_major)))
+ LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
+ ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
{{1.0, 6.0}, {7.0, -4.0}},
- MinorToMajorForIsRowMajor(rhs_row_major)))
+ LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))))
.ConsumeValueOrDie();
ComputationBuilder builder(client_, TestName());
@@ -277,6 +277,62 @@ XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorFF) {
TestMatrixDot(260, 3, 520, false, false);
}
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x8) {
+ TestMatrixDot(1, 8, 8, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x130x8) {
+ TestMatrixDot(1, 130, 8, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x130) {
+ TestMatrixDot(1, 8, 130, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x290x130) {
+ TestMatrixDot(1, 290, 130, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_2x1x1) {
+ TestMatrixDot(2, 1, 1, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_8x8x1) {
+ TestMatrixDot(8, 8, 1, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_16x1x1) {
+ TestMatrixDot(16, 1, 1, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_16x3x1) {
+ TestMatrixDot(16, 3, 1, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_3x3x1) {
+ TestMatrixDot(3, 3, 1, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_29x29x1) {
+ TestMatrixDot(29, 29, 1, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x2) {
+ TestMatrixDot(1, 8, 2, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x2x8) {
+ TestMatrixDot(1, 2, 8, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_259x258x1) {
+ TestMatrixDot(259, 258, 1, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_259x258x1_FT) {
+ TestMatrixDot(259, 258, 1, false, true);
+}
+
XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) {
constexpr bool kLhsRowMajor = false;
constexpr bool kRhsRowMajor = false;
@@ -306,15 +362,15 @@ void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major,
bool rhs_row_major) {
auto lhs_handle =
client_
- ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
+ ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
{{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}},
- MinorToMajorForIsRowMajor(lhs_row_major)))
+ LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
+ ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
{{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}},
- MinorToMajorForIsRowMajor(rhs_row_major)))
+ LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))))
.ConsumeValueOrDie();
ComputationBuilder builder(client_, TestName());
@@ -361,6 +417,31 @@ XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64) {
TestNonsquareMatrixDot<complex64>();
}
+XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
+ auto lhs_handle =
+ client_
+ ->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
+ {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
+ .ConsumeValueOrDie();
+ auto rhs_handle =
+ client_
+ ->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
+ {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
+ LayoutUtil::MakeLayout({1, 0})))
+ .ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto prim_type = primitive_util::NativeToPrimitiveType<complex64>();
+ auto result = builder.Dot(
+ builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"),
+ builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs"));
+
+ Array2D<complex64> expected({{30.0, -2.0}});
+
+ ComputeAndCompareR2<complex64>(
+ &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_);
+}
+
XLA_TEST_F(DotOperationTest, ConcurrentMatMul) {
ComputationBuilder builder(client_, TestName());
auto matrix1 = builder.ConstantR2<float>({{1.0, 2.0}, {3.0, 4.0}});
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc
index 95a52ecd2f..75c9a0d3fb 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util.cc
@@ -116,16 +116,18 @@ template <typename FloatT, typename UnsignedT>
::testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
+ auto lhs_double = static_cast<double>(lhs);
+ auto rhs_double = static_cast<double>(rhs);
if (ulhs != urhs) {
return ::testing::AssertionFailure() << tensorflow::strings::Printf(
"floating values are not bitwise-equal; and equality testing "
"was requested: %s=%g=%a vs %s=%g=%a",
tensorflow::strings::StrCat(tensorflow::strings::Hex(ulhs))
.c_str(),
- lhs, lhs,
+ lhs_double, lhs_double,
tensorflow::strings::StrCat(tensorflow::strings::Hex(urhs))
.c_str(),
- rhs, rhs);
+ rhs_double, rhs_double);
}
return ::testing::AssertionSuccess();
}
@@ -149,6 +151,10 @@ template <typename NativeT>
// Specializations for floating types that do bitwise comparisons when equality
// comparison is requested.
template <>
+::testing::AssertionResult CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs) {
+ return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs);
+}
+template <>
::testing::AssertionResult CompareEqual<float>(float lhs, float rhs) {
return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs);
}
@@ -238,6 +244,9 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
case U64:
match = ExpectLiteralsEqual<uint64>(expected, actual, &multi_index, 0);
break;
+ case BF16:
+ match = ExpectLiteralsEqual<bfloat16>(expected, actual, &multi_index, 0);
+ break;
case F32:
match = ExpectLiteralsEqual<float>(expected, actual, &multi_index, 0);
break;
diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
index 458258e7ee..70d8b764a3 100644
--- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
+++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
@@ -14,49 +14,118 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/stream_executor/stream_executor.h"
namespace xla {
namespace {
-class LLVMCompilerTest : public HloTestBase {};
-
-XLA_TEST_F(LLVMCompilerTest, CompilerHooks) {
- int pre_opt_hook_call_count = 0;
- int post_opt_hook_call_count = 0;
-
- auto pre_opt_hook = [&pre_opt_hook_call_count](const llvm::Module &) {
- ++pre_opt_hook_call_count;
- return Status::OK();
- };
- auto post_opt_hook = [&post_opt_hook_call_count](const llvm::Module &) {
- ++post_opt_hook_call_count;
- return Status::OK();
- };
-
- // Create HLO module, and run the compiler.
- auto builder = HloComputation::Builder(TestName());
- builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
-
- auto hlo_module = CreateNewModule();
- hlo_module->AddEntryComputation(builder.Build());
-
- auto compiler = static_cast<LLVMCompiler *>(backend().compiler());
- compiler->SetPreOptimizationHook(pre_opt_hook);
- compiler->SetPostOptimizationHook(post_opt_hook);
-
- ASSERT_TRUE(
- compiler
- ->Compile(std::move(hlo_module), backend().default_stream_executor())
- .ok());
-
- // Test that hooks were called.
- EXPECT_EQ(1, pre_opt_hook_call_count);
- EXPECT_EQ(1, post_opt_hook_call_count);
+class LLVMCompilerTest : public ::testing::Test {
+ public:
+ void SetUp() override {
+ Platform *platform = FindPlatform();
+ ASSERT_NE(platform, nullptr);
+
+ BackendOptions backend_options;
+ backend_options.set_platform(platform);
+ StatusOr<std::unique_ptr<Backend>> backend_or_status =
+ Backend::CreateBackend(backend_options);
+ ASSERT_IS_OK(backend_or_status.status());
+ backend_ = backend_or_status.ConsumeValueOrDie();
+ }
+
+ ~LLVMCompilerTest() override {}
+
+ protected:
+ using Platform = ::perftools::gputools::Platform;
+
+ explicit LLVMCompilerTest(string platform_name)
+ : platform_name_(std::move(platform_name)) {}
+
+ void TestCompilerHooks(LLVMCompiler *compiler) {
+ int pre_opt_hook_call_count = 0;
+ int post_opt_hook_call_count = 0;
+
+ auto pre_opt_hook = [&pre_opt_hook_call_count](const llvm::Module &) {
+ ++pre_opt_hook_call_count;
+ return Status::OK();
+ };
+ auto post_opt_hook = [&post_opt_hook_call_count](const llvm::Module &) {
+ ++post_opt_hook_call_count;
+ return Status::OK();
+ };
+
+ // Create HLO module, and run the compiler.
+ auto builder = HloComputation::Builder(TestName());
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+
+ auto hlo_module = CreateNewModule();
+ hlo_module->AddEntryComputation(builder.Build());
+
+ compiler->SetPreOptimizationHook(pre_opt_hook);
+ compiler->SetPostOptimizationHook(post_opt_hook);
+
+ ASSERT_TRUE(compiler
+ ->Compile(std::move(hlo_module),
+ backend_->default_stream_executor())
+ .ok());
+
+ // Test that hooks were called.
+ EXPECT_EQ(1, pre_opt_hook_call_count);
+ EXPECT_EQ(1, post_opt_hook_call_count);
+ }
+
+ private:
+ Platform *FindPlatform() {
+ for (Platform *platform :
+ PlatformUtil::GetSupportedPlatforms().ConsumeValueOrDie()) {
+ if (platform->Name() == platform_name_) {
+ return platform;
+ }
+ }
+ return nullptr;
+ }
+
+ string platform_name_;
+ std::unique_ptr<Backend> backend_;
+
+ static string TestName() {
+ return ::testing::UnitTest::GetInstance()->current_test_info()->name();
+ }
+
+ static std::unique_ptr<HloModule> CreateNewModule() {
+ HloModuleConfig config;
+ config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
+ return MakeUnique<HloModule>(TestName(), VersionedComputationHandle(),
+ config);
+ }
+};
+
+class CpuCompilerTest : public LLVMCompilerTest {
+ public:
+ CpuCompilerTest() : LLVMCompilerTest("Host") {}
+};
+
+class GpuCompilerTest : public LLVMCompilerTest {
+ public:
+ GpuCompilerTest() : LLVMCompilerTest("CUDA") {}
+};
+
+TEST_F(CpuCompilerTest, HooksTest) {
+ cpu::CpuCompiler compiler;
+ TestCompilerHooks(&compiler);
+}
+
+TEST_F(GpuCompilerTest, HooksTest) {
+ gpu::GpuCompiler compiler;
+ TestCompilerHooks(&compiler);
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
index 329b53012f..a196e250d1 100644
--- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
@@ -136,16 +136,14 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
auto computation = builder.Build().ConsumeValueOrDie();
// Create x as a col-major array.
- auto x_array = LiteralToShapedBuffer(
- *test_utils::CreateR2LiteralWithLayout({{1.0f, 2.0f}, {3.0f, 4.0f}},
- /*minor_to_major=*/{0, 1}));
+ auto x_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout(
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})));
EXPECT_TRUE(LayoutUtil::Equal(x_array->shape().layout(),
LayoutUtil::MakeLayout({0, 1})));
// Create y as a row-major array.
- auto y_array = LiteralToShapedBuffer(
- *test_utils::CreateR2LiteralWithLayout({{10.0f, 20.0f}, {30.0f, 40.0f}},
- /*minor_to_major=*/{1, 0}));
+ auto y_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout(
+ {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0})));
EXPECT_TRUE(LayoutUtil::Equal(y_array->shape().layout(),
LayoutUtil::MakeLayout({1, 0})));
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc
index c11e1df0a7..d98875dbc2 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.cc
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc
@@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#define EIGEN_USE_THREADS
#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
#include <vector>
-#define EIGEN_USE_THREADS
-
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/map_util.h"
diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc
index 2ef392508d..2b0f7e6e80 100644
--- a/tensorflow/compiler/xla/tests/map_test.cc
+++ b/tensorflow/compiler/xla/tests/map_test.cc
@@ -405,13 +405,13 @@ TEST_F(MapTest, MapBinaryAdder) {
// for Map that used to fail in shape inference (b/28989438).
XLA_TEST_F(MapTest, AddWithMixedLayouts) {
ComputationBuilder builder(client_, TestName());
- std::unique_ptr<Literal> param0_literal =
- test_utils::CreateR2LiteralWithLayout({{1, 2}, {3, 4}}, {1, 0});
+ std::unique_ptr<Literal> param0_literal = Literal::CreateR2WithLayout(
+ {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0}));
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
- test_utils::CreateR2LiteralWithLayout({{10, 20}, {30, 40}}, {0, 1});
+ std::unique_ptr<Literal> param1_literal = Literal::CreateR2WithLayout(
+ {{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc
index 72c68f24a0..d235b9a158 100644
--- a/tensorflow/compiler/xla/tests/reshape_test.cc
+++ b/tensorflow/compiler/xla/tests/reshape_test.cc
@@ -431,8 +431,9 @@ XLA_TEST_F(ReshapeTest, ToScalar) {
XLA_TEST_F(ReshapeTest, BadDimensions) {
ComputationBuilder b(client_, TestName());
b.Reshape(b.ConstantR1<int32>({1}), {}, {});
- EXPECT_THAT(ExecuteToString(&b, {}),
- ::testing::HasSubstr("dimensions not a permutation"));
+ EXPECT_THAT(
+ ExecuteToString(&b, {}),
+ ::testing::HasSubstr("not a permutation of the operand dimensions"));
}
XLA_TEST_F(ReshapeTest, BadNewSizes) {
diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h
index 3878ac1013..bea0b5ef92 100644
--- a/tensorflow/compiler/xla/tests/test_macros.h
+++ b/tensorflow/compiler/xla/tests/test_macros.h
@@ -96,7 +96,8 @@ string PrependDisabledIfIndicated(const string& test_case_name,
test_name)::test_info_ = \
::testing::internal::MakeAndRegisterTestInfo( \
#test_case_name, \
- PrependDisabledIfIndicated(#test_case_name, #test_name).c_str(), \
+ ::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \
+ .c_str(), \
nullptr, nullptr, \
::testing::internal::CodeLocation(__FILE__, __LINE__), (parent_id), \
parent_class::SetUpTestCase, parent_class::TearDownTestCase, \
@@ -135,7 +136,8 @@ string PrependDisabledIfIndicated(const string& test_case_name,
::testing::internal::CodeLocation(__FILE__, __LINE__)) \
->AddTestPattern( \
#test_case_name, \
- PrependDisabledIfIndicated(#test_case_name, #test_name).c_str(), \
+ ::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \
+ .c_str(), \
new ::testing::internal::TestMetaFactory<GTEST_TEST_CLASS_NAME_( \
test_case_name, test_name)>()); \
return 0; \
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
new file mode 100644
index 0000000000..cdd3d66bbb
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -0,0 +1,120 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+
+#include "tensorflow/compiler/xla/primitive_util.h"
+
+namespace xla {
+
+namespace {
+
+template <typename FloatT>
+void PopulateWithRandomFloatingPointData(Literal* literal) {
+ CHECK_EQ(literal->shape().element_type(),
+ primitive_util::NativeToPrimitiveType<FloatT>());
+ std::minstd_rand0 engine;
+ std::uniform_real_distribution<FloatT> generator(0.0f, 1.0f);
+ TF_CHECK_OK(literal->Populate<FloatT>(
+ [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
+ return generator(engine);
+ }));
+}
+
+template <typename IntT>
+void PopulateWithRandomIntegralData(Literal* literal) {
+ CHECK_EQ(literal->shape().element_type(),
+ primitive_util::NativeToPrimitiveType<IntT>());
+ std::minstd_rand0 engine;
+ std::uniform_int_distribution<IntT> generator(
+ std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
+ TF_CHECK_OK(literal->Populate<IntT>(
+ [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
+ return generator(engine);
+ }));
+}
+
+} // namespace
+
+StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
+ if (ShapeUtil::IsTuple(shape)) {
+ std::vector<std::unique_ptr<Literal>> elements;
+ for (const Shape& element_shape : shape.tuple_shapes()) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element,
+ MakeFakeLiteral(element_shape));
+ elements.push_back(std::move(element));
+ }
+ return Literal::MakeTupleOwned(std::move(elements));
+ }
+ std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
+ switch (shape.element_type()) {
+ case F32:
+ PopulateWithRandomFloatingPointData<float>(literal.get());
+ break;
+ case F64:
+ PopulateWithRandomFloatingPointData<double>(literal.get());
+ break;
+ case S8:
+ PopulateWithRandomIntegralData<int8>(literal.get());
+ break;
+ case U8:
+ PopulateWithRandomIntegralData<uint8>(literal.get());
+ break;
+ case S16:
+ PopulateWithRandomIntegralData<int16>(literal.get());
+ break;
+ case U16:
+ PopulateWithRandomIntegralData<uint16>(literal.get());
+ break;
+ case S32:
+ PopulateWithRandomIntegralData<int32>(literal.get());
+ break;
+ case U32:
+ PopulateWithRandomIntegralData<uint32>(literal.get());
+ break;
+ case S64:
+ PopulateWithRandomIntegralData<int64>(literal.get());
+ break;
+ case U64:
+ PopulateWithRandomIntegralData<uint64>(literal.get());
+ break;
+ case PRED: {
+ std::uniform_int_distribution<int> generator(0, 1);
+ std::minstd_rand0 engine;
+ TF_CHECK_OK(literal->Populate<bool>(
+ [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
+ return generator(engine);
+ }));
+ break;
+ }
+ default:
+ return Unimplemented("Unsupported type for fake literal generation: %s",
+ ShapeUtil::HumanString(shape).c_str());
+ }
+ return std::move(literal);
+}
+
+StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
+ const HloModule& module) {
+ std::vector<std::unique_ptr<Literal>> arguments;
+ for (const ShapeLayout& shape_layout :
+ module.config().entry_computation_layout().parameter_layouts()) {
+ TF_ASSIGN_OR_RETURN(auto literal, MakeFakeLiteral(shape_layout.shape()));
+ arguments.push_back(std::move(literal));
+ }
+ return std::move(arguments);
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h
index f3a522b05e..12d5255fce 100644
--- a/tensorflow/compiler/xla/tests/test_utils.h
+++ b/tensorflow/compiler/xla/tests/test_utils.h
@@ -23,12 +23,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
-namespace test_utils {
// A class which generates pseudorandom numbers of a given type within a given
// range. Not cryptographically secure and likely not perfectly evenly
@@ -53,63 +53,15 @@ class PseudorandomGenerator {
std::mt19937 generator_;
};
-// Convenience function for creating a rank-2 array with arbitrary layout.
-template <typename NativeT>
-std::unique_ptr<Literal> CreateR2LiteralWithLayout(
- std::initializer_list<std::initializer_list<NativeT>> values,
- tensorflow::gtl::ArraySlice<int64> minor_to_major) {
- auto literal = MakeUnique<Literal>();
- const int64 d0 = values.size();
- const int64 d1 = values.begin()->size();
- literal.get()->PopulateWithValue<NativeT>(0, {d0, d1});
- *literal->mutable_shape()->mutable_layout() =
- LayoutUtil::MakeLayout(minor_to_major);
- TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape()));
-
- int64 dim0 = 0;
- for (auto inner_list : values) {
- int64 dim1 = 0;
- for (auto value : inner_list) {
- literal.get()->Set({dim0, dim1}, value);
- ++dim1;
- }
- ++dim0;
- }
- return literal;
-}
+// Generates fake data in a literal of the given shape, or returns an error
+// status if the element type is currently unhandled for fake data generation.
+StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape);
-// Convenience function for creating a rank-3 array with arbitrary layout.
-template <typename NativeT>
-std::unique_ptr<Literal> CreateR3LiteralWithLayout(
- std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
- values,
- tensorflow::gtl::ArraySlice<int64> minor_to_major) {
- auto literal = MakeUnique<Literal>();
- const int64 d0 = values.size();
- const int64 d1 = values.begin()->size();
- const int64 d2 = values.begin()->begin()->size();
- literal.get()->PopulateWithValue<NativeT>(0, {d0, d1, d2});
- *literal->mutable_shape()->mutable_layout() =
- LayoutUtil::MakeLayout(minor_to_major);
- TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape()));
-
- int64 dim0 = 0;
- for (auto inner_list : values) {
- int64 dim1 = 0;
- for (auto inner_inner_list : inner_list) {
- int64 dim2 = 0;
- for (auto value : inner_inner_list) {
- literal.get()->Set({dim0, dim1, dim2}, value);
- ++dim2;
- }
- ++dim1;
- }
- ++dim0;
- }
- return literal;
-}
+// Generates a vector of arguments containing fake data. The number, shape and
+// layout of the arguments is appropriate for given HLO module.
+StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
+ const HloModule& module);
-} // namespace test_utils
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_
diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index 759921dce5..091fa0c3ec 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -88,6 +88,7 @@ cc_library(
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:testing",
"//tensorflow/compiler/xla/service:session_proto",
+ "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/tools/parser/README.md
index 2c864d77a2..b768b94e77 100644
--- a/tensorflow/compiler/xla/tools/parser/README.md
+++ b/tensorflow/compiler/xla/tools/parser/README.md
@@ -43,14 +43,22 @@ operand
: shape name
;
-extra_attributes
+attributes
: /*empty*/
- | ',' extra_attribute
- | ',' extra_attribute extra_attributes
+ | ',' attribute
+ | ',' attribute attributes
;
-extra_attribute
+attribute
: attribute_name attribute_value
;
+attribute_value
+ : kInt
+ | kName
+ | [0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,} /*dim_labels_pattern*/
+ | [0-9]+(x[0-9]+)+ /*dxd_pattern*/
+ | [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* /*pad_pattern*/
+ | '{' sub_attributes '}'
+ ;
param_list
: '(' param_list1 ')'
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc
index d104ff3460..098879155a 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/regexp.h"
namespace xla {
@@ -122,7 +123,7 @@ TokKind HloLexer::LexToken() {
current_ptr_++;
return TokKind::kArrow;
}
- return LexDigitOrNegative();
+ return LexNumberOrPattern();
case '=':
return TokKind::kEqual;
case ',':
@@ -145,16 +146,21 @@ TokKind HloLexer::LexToken() {
return TokKind::kRparen;
case '/':
return LexComment();
+ case '"':
+ return LexString();
}
}
}
-// Lex a shape, name, keyword, or opcode.
+// Lex a shape, name, keyword, opcode, attribute name, or the dim labels
+// pattern.
+//
// shape ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})?
// name ::= [a-zA-Z_][a-zA-Z0-9_.-]*:
// keyword ::= HloModule, ENTRY, ...
// opcode ::= add, greater-than, ...
// attribute_name ::= condition, body, dimensions, ...
+// dim_labels_pattern ::= [0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,}
TokKind HloLexer::LexIdentifier() {
{
auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
@@ -220,6 +226,16 @@ TokKind HloLexer::LexIdentifier() {
return TokKind::kOpcode;
}
+ {
+ auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
+ static LazyRE2 dim_labels_pattern = {
+ R"([0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,})"};
+ if (RE2::Consume(&consumable, *dim_labels_pattern)) {
+ current_ptr_ = consumable.begin();
+ str_val_.assign(token_start_, current_ptr_);
+ return TokKind::kDimLabels;
+ }
+ }
current_ptr_ = token_start_ + 1;
return TokKind::kError;
}
@@ -240,15 +256,20 @@ TokKind HloLexer::LexPercent() {
return TokKind::kError;
}
-// Lex integer and floating-point values, and -inf.
-// int [-]?[0-9]+
-// fp with exp [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+)
-// fp without exp [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+)
-// negative inf -inf
-TokKind HloLexer::LexDigitOrNegative() {
+// Lex integer and floating-point values, -inf, and patterns for dim labels,
+// dxd (e.g. 1x2x3), and pad.
+//
+// fp with exp ::= [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+)
+// fp without exp ::= [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+)
+// dim_labels_pattern ::= [0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,}
+// dxd_pattern ::= [0-9]+(x[0-9]+)+
+// pad_pattern ::= [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)*
+// int ::= [-]?[0-9]+
+// negative inf ::= '-inf'
+TokKind HloLexer::LexNumberOrPattern() {
auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
static LazyRE2 float_pattern = {
- R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|(\d+[.]\d*|\d*[.]\d+))"};
+ R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"};
if (RE2::Consume(&consumable, *float_pattern)) {
current_ptr_ = consumable.begin();
tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(),
@@ -256,6 +277,30 @@ TokKind HloLexer::LexDigitOrNegative() {
return TokKind::kDecimal;
}
+ static LazyRE2 dim_labels_pattern = {
+ R"([0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,})"};
+ static LazyRE2 dxd_pattern = {R"([0-9]+(x[0-9]+)+)"};
+ static LazyRE2 pad_pattern = {
+ R"([0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)*)"};
+
+ if (RE2::Consume(&consumable, *dim_labels_pattern)) {
+ current_ptr_ = consumable.begin();
+ str_val_.assign(token_start_, current_ptr_);
+ return TokKind::kDimLabels;
+ }
+
+ if (RE2::Consume(&consumable, *dxd_pattern)) {
+ current_ptr_ = consumable.begin();
+ str_val_.assign(token_start_, current_ptr_);
+ return TokKind::kDxD;
+ }
+
+ if (RE2::Consume(&consumable, *pad_pattern)) {
+ current_ptr_ = consumable.begin();
+ str_val_.assign(token_start_, current_ptr_);
+ return TokKind::kPad;
+ }
+
static LazyRE2 int_pattern = {R"([-]?\d+)"};
if (RE2::Consume(&consumable, *int_pattern)) {
current_ptr_ = consumable.begin();
@@ -298,6 +343,25 @@ TokKind HloLexer::LexComment() {
return TokKind::kError;
}
+// Lexes quoted string with escaping characters. If matched, the quoted string
+// will be unescaped and stored to str_val_.
+TokKind HloLexer::LexString() {
+ auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
+ static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"};
+ if (RE2::Consume(&consumable, *escaping_pattern)) {
+ current_ptr_ = consumable.begin();
+ StringPiece raw =
+ StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1);
+ string error;
+ if (!tensorflow::str_util::CUnescape(raw, &str_val_, &error)) {
+ LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error;
+ return TokKind::kError;
+ }
+ return TokKind::kString;
+ }
+ return TokKind::kError;
+}
+
string TokKindToString(TokKind kind) {
switch (kind) {
case TokKind::kEof:
@@ -350,6 +414,14 @@ string TokKindToString(TokKind kind) {
return "kName";
case TokKind::kAttributeName:
return "kAttributeName";
+ case TokKind::kDimLabels:
+ return "kDimLabels";
+ case TokKind::kDxD:
+ return "kDxD";
+ case TokKind::kPad:
+ return "kPad";
+ case TokKind::kString:
+ return "kString";
case TokKind::kShape:
return "kShape";
case TokKind::kOpcode:
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h
index 3b9efcb92d..2236c26619 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h
+++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h
@@ -37,11 +37,16 @@ class HloLexer {
}
TokKind Lex() { return current_kind_ = LexToken(); }
+
TokKind GetKind() const { return current_kind_; }
string GetStrVal() const {
switch (GetKind()) {
case TokKind::kName:
case TokKind::kAttributeName:
+ case TokKind::kDimLabels:
+ case TokKind::kDxD:
+ case TokKind::kPad:
+ case TokKind::kString:
return str_val_;
default:
LOG(FATAL) << "This token does not have string value";
@@ -92,8 +97,9 @@ class HloLexer {
TokKind LexPercent();
TokKind LexShape();
TokKind LexConstant();
- TokKind LexDigitOrNegative();
+ TokKind LexNumberOrPattern();
TokKind LexComment();
+ TokKind LexString();
const tensorflow::StringPiece buf_;
const char* current_ptr_;
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index 6c2e37e3b5..ac7d9ff482 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -28,6 +28,9 @@ namespace tools {
namespace {
using tensorflow::StringPiece;
+using tensorflow::gtl::optional;
+using tensorflow::str_util::Split;
+using tensorflow::str_util::SplitAndParseAsInts;
using tensorflow::strings::Printf;
using tensorflow::strings::StrAppend;
using tensorflow::strings::StrCat;
@@ -57,7 +60,6 @@ class HloParser {
bool ParseInstructionList(HloComputation::Builder* builder,
string* root_name);
bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
- bool ParseSharding(HloInstruction* instruction);
bool ParseControlPredecessors(HloInstruction* instruction);
bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
@@ -78,14 +80,96 @@ class HloParser {
bool ParseOperands(std::vector<HloInstruction*>* operands,
const int expected_size);
- template <typename T>
- bool ParseExtraAttribute(T* value, const string& expected_attribute);
- template <typename T>
- bool ParseAttributeValue(T* value);
+ // Describes the start, limit, and stride on every dimension of the operand
+ // being sliced.
+ struct SliceRanges {
+ std::vector<int64> starts;
+ std::vector<int64> limits;
+ std::vector<int64> strides;
+ };
+
+ // Types of attributes.
+ enum class AttrTy {
+ kInt64,
+ kInt32,
+ kFloat,
+ kString,
+ kBracedInt64List,
+ kHloComputation,
+ kWindow,
+ kConvolutionDimensionNumbers,
+ kSharding,
+ kInstructionList,
+ kSliceRanges,
+ kPaddingConfig,
+ kMetadata,
+ };
+
+ struct AttrConfig {
+ bool required; // whether it's required or optional
+ AttrTy attr_type; // what type it is
+ void* result; // where to store the parsed result.
+ };
+
+ // attributes ::= (',' attribute)*
+ //
+ // Parses attributes given names and configs of the attributes. Each parsed
+ // result is passed back through the result pointer in corresponding
+ // AttrConfig. Note that the result pointer must point to a optional<T> typed
+ // variable which outlives this function. Returns false on error. You should
+ // not use the any of the results if this function failed.
+ //
+ // Example usage:
+ //
+ // std::unordered_map<string, AttrConfig> attrs;
+ // optional<int64> foo;
+ // attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo};
+ // optional<Window> bar;
+ // attrs["bar"] = {/*required=*/true, AttrTy::kWindow, &bar};
+ // if (!ParseAttributes(attrs)) {
+ // return false; // Do not use 'foo' 'bar' if failed.
+ // }
+ // // Do something with 'bar'.
+ // if (foo) { // If attr foo is seen, do something with 'foo'. }
+ //
+ bool ParseAttributes(const std::unordered_map<string, AttrConfig>& attrs);
+
+ // sub_attributes ::= '{' (','? attribute)* '}'
+ //
+ // Usage is the same as ParseAttributes. See immediately above.
+ bool ParseSubAttributes(const std::unordered_map<string, AttrConfig>& attrs);
+
+ // Parses one attribute. If it has already been seen, return error. Returns
+ // true and adds to seen_attrs on success.
+ //
+ // Do not call this except in ParseAttributes or ParseSubAttributes.
+ bool ParseAttributeHelper(const std::unordered_map<string, AttrConfig>& attrs,
+ std::unordered_set<string>* seen_attrs);
+
+ // Parses a name and finds the corresponding hlo computation.
+ bool ParseComputationName(HloComputation** value);
+ // Parses a list of names and finds the corresponding hlo instructions.
+ bool ParseInstructionNames(std::vector<HloInstruction*>* instructions);
+ bool ParseWindow(Window* window);
+ bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums);
+ bool ParsePaddingConfig(PaddingConfig* padding);
+ bool ParseMetadata(OpMetadata* metadata);
+ bool ParseSharding(OpSharding* sharding);
+ bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
+
+ // Parses a sub-attribute of the window attribute, e.g.,size=1x2x3.
+ bool ParseDxD(const string& name, std::vector<int64>* result);
+ // Parses window's pad sub-attriute, e.g., pad=0_0x3x3.
+ bool ParseWindowPad(std::vector<std::vector<int64>>* pad);
+
+ bool ParseSliceRanges(SliceRanges* result);
+ bool ParseInt64List(const TokKind start, const TokKind end,
+ const TokKind delim, std::vector<int64>* result);
bool ParseParamList();
bool ParseName(string* result);
bool ParseAttributeName(string* result);
+ bool ParseString(string* result);
bool ParseShape(Shape* result);
bool ParseOpcode(HloOpcode* result);
bool ParseInt64(int64* result);
@@ -214,7 +298,7 @@ bool HloParser::ParseInstructionList(HloComputation::Builder* builder,
"expects '}' at the end of instruction list.");
}
-// instruction ::= ('ROOT')? name '=' shape opcode operands (extra_attribute)*
+// instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)*
bool HloParser::ParseInstruction(HloComputation::Builder* builder,
string* root_name) {
string name;
@@ -230,6 +314,17 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (is_root) {
*root_name = name;
}
+
+ // Add optional attributes.
+ std::unordered_map<string, AttrConfig> attrs;
+ optional<OpSharding> sharding;
+ attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
+ optional<std::vector<HloInstruction*>> predecessors;
+ attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList,
+ &predecessors};
+ optional<OpMetadata> metadata;
+ attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata};
+
HloInstruction* instruction;
switch (opcode) {
case HloOpcode::kParameter: {
@@ -237,7 +332,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (!ParseToken(TokKind::kLparen,
"expects '(' before parameter number") ||
!ParseInt64(&parameter_number) ||
- !ParseToken(TokKind::kRparen, "expects ')' after parameter number")) {
+ !ParseToken(TokKind::kRparen, "expects ')' after parameter number") ||
+ !ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
@@ -249,7 +345,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (!ParseToken(TokKind::kLparen,
"expects '(' before constant literal") ||
!ParseLiteral(&literal, shape) ||
- !ParseToken(TokKind::kRparen, "expects ')' after constant literal")) {
+ !ParseToken(TokKind::kRparen, "expects ')' after constant literal") ||
+ !ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
@@ -275,7 +372,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kSin:
case HloOpcode::kSort:
case HloOpcode::kTanh: {
- if (!ParseOperands(&operands, /*expected_size=*/1)) {
+ if (!ParseOperands(&operands, /*expected_size=*/1) ||
+ !ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
@@ -305,7 +403,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical: {
- if (!ParseOperands(&operands, /*expected_size=*/2)) {
+ if (!ParseOperands(&operands, /*expected_size=*/2) ||
+ !ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateBinary(
@@ -315,7 +414,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
// Ternary ops.
case HloOpcode::kClamp:
case HloOpcode::kSelect: {
- if (!ParseOperands(&operands, /*expected_size=*/3)) {
+ if (!ParseOperands(&operands, /*expected_size=*/3) ||
+ !ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateTernary(
@@ -324,7 +424,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
// Other supported ops.
case HloOpcode::kConvert: {
- if (!ParseOperands(&operands, /*expected_size=*/1)) {
+ if (!ParseOperands(&operands, /*expected_size=*/1) ||
+ !ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
@@ -332,7 +433,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kCrossReplicaSum: {
- if (!ParseOperands(&operands, /*expected_size=*/1)) {
+ if (!ParseOperands(&operands, /*expected_size=*/1) ||
+ !ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
@@ -340,7 +442,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kReshape: {
- if (!ParseOperands(&operands, /*expected_size=*/1)) {
+ if (!ParseOperands(&operands, /*expected_size=*/1) ||
+ !ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
@@ -348,7 +451,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kTuple: {
- if (!ParseOperands(&operands)) {
+ if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
instruction =
@@ -356,126 +459,379 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kWhile: {
- HloComputation* condition;
- HloComputation* body;
+ optional<HloComputation*> condition;
+ optional<HloComputation*> body;
+ attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation,
+ &condition};
+ attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
- !ParseExtraAttribute(&condition,
- /*expected_attribute=*/"condition") ||
- !ParseExtraAttribute(&body, /*expected_attribute=*/"body")) {
+ !ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateWhile(
- shape, condition, body, /*init=*/operands[0]));
+ shape, *condition, *body, /*init=*/operands[0]));
break;
}
case HloOpcode::kRecv: {
- int64 channel_id;
+ optional<int64> channel_id;
+ attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
if (!ParseOperands(&operands, /*expected_size=*/0) ||
- !ParseExtraAttribute(&channel_id,
- /*expected_attribute=*/"channel_id")) {
+ !ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
- HloInstruction::CreateRecv(shape, channel_id));
+ HloInstruction::CreateRecv(shape.tuple_shapes(0), *channel_id));
+ break;
+ }
+ case HloOpcode::kRecvDone: {
+ optional<int64> channel_id;
+ attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
+ if (!ParseOperands(&operands, /*expected_size=*/1) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+ if (channel_id != operands[0]->channel_id()) {
+ return false;
+ }
+ instruction =
+ builder->AddInstruction(HloInstruction::CreateRecvDone(operands[0]));
break;
}
case HloOpcode::kSend: {
- int64 channel_id;
+ optional<int64> channel_id;
+ attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
- !ParseExtraAttribute(&channel_id,
- /*expected_attribute=*/"channel_id")) {
+ !ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
- HloInstruction::CreateSend(operands[0], channel_id));
+ HloInstruction::CreateSend(operands[0], *channel_id));
+ break;
+ }
+ case HloOpcode::kSendDone: {
+ optional<int64> channel_id;
+ attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
+ if (!ParseOperands(&operands, /*expected_size=*/1) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+ if (channel_id != operands[0]->channel_id()) {
+ return false;
+ }
+ instruction =
+ builder->AddInstruction(HloInstruction::CreateSendDone(operands[0]));
break;
}
case HloOpcode::kGetTupleElement: {
- int64 index;
+ optional<int64> index;
+ attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
- !ParseExtraAttribute(&index, /*expected_attribute=*/"index")) {
+ !ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
- HloInstruction::CreateGetTupleElement(shape, operands[0], index));
+ HloInstruction::CreateGetTupleElement(shape, operands[0], *index));
break;
}
case HloOpcode::kCall: {
- HloComputation* to_apply;
- if (!ParseOperands(&operands) ||
- !ParseExtraAttribute(&to_apply,
- /*expected_attribute=*/"to_apply")) {
+ optional<HloComputation*> to_apply;
+ attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
+ &to_apply};
+ if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateCall(shape, operands, *to_apply));
+ break;
+ }
+ case HloOpcode::kReduceWindow: {
+ optional<HloComputation*> reduce_computation;
+ optional<Window> window;
+ attrs["window"] = {/*required=*/true, AttrTy::kWindow, &window};
+ attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
+ &reduce_computation};
+ if (!ParseOperands(&operands, /*expected_size=*/2) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow(
+ shape, /*operand=*/operands[0], /*init_value=*/operands[1], *window,
+ *reduce_computation));
+ break;
+ }
+ case HloOpcode::kConvolution: {
+ optional<Window> window;
+ optional<ConvolutionDimensionNumbers> dnums;
+ attrs["window"] = {/*required=*/true, AttrTy::kWindow, &window};
+ attrs["dim_labels"] = {/*required=*/true,
+ AttrTy::kConvolutionDimensionNumbers, &dnums};
+ if (!ParseOperands(&operands, /*expected_size=*/2) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
+ shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums));
+ break;
+ }
+ case HloOpcode::kBroadcast: {
+ optional<std::vector<int64>> broadcast_dimensions;
+ attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
+ &broadcast_dimensions};
+ if (!ParseOperands(&operands, /*expected_size=*/1) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction = builder->AddInstruction(HloInstruction::CreateBroadcast(
+ shape, operands[0], *broadcast_dimensions));
+ break;
+ }
+ case HloOpcode::kConcatenate: {
+ optional<std::vector<int64>> dimensions;
+ attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
+ &dimensions};
+ if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
+ dimensions->size() != 1) {
+ return false;
+ }
+ instruction = builder->AddInstruction(HloInstruction::CreateConcatenate(
+ shape, operands, dimensions->at(0)));
+ break;
+ }
+ case HloOpcode::kMap: {
+ optional<HloComputation*> to_apply;
+ attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
+ &to_apply};
+ if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateMap(shape, operands, *to_apply));
+ break;
+ }
+ case HloOpcode::kReduce: {
+ optional<HloComputation*> reduce_computation;
+ attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
+ &reduce_computation};
+ optional<std::vector<int64>> dimensions_to_reduce;
+ attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
+ &dimensions_to_reduce};
+ if (!ParseOperands(&operands, /*expected_size=*/2) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction = builder->AddInstruction(HloInstruction::CreateReduce(
+ shape, /*operand=*/operands[0], /*init_value=*/operands[1],
+ *dimensions_to_reduce, *reduce_computation));
+ break;
+ }
+ case HloOpcode::kReverse: {
+ optional<std::vector<int64>> dimensions;
+ attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
+ &dimensions};
+ if (!ParseOperands(&operands, /*expected_size=*/1) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateReverse(shape, operands[0], *dimensions));
+ break;
+ }
+ case HloOpcode::kSelectAndScatter: {
+ optional<HloComputation*> select;
+ attrs["select"] = {/*required=*/true, AttrTy::kHloComputation, &select};
+ optional<HloComputation*> scatter;
+ attrs["scatter"] = {/*required=*/true, AttrTy::kHloComputation, &scatter};
+ optional<Window> window;
+ attrs["window"] = {/*required=*/true, AttrTy::kWindow, &window};
+ if (!ParseOperands(&operands, /*expected_size=*/3) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction =
+ builder->AddInstruction(HloInstruction::CreateSelectAndScatter(
+ shape, /*operand=*/operands[0], *select, *window,
+ /*source=*/operands[1], /*init_value=*/operands[2], *scatter));
+ break;
+ }
+ case HloOpcode::kSlice: {
+ optional<SliceRanges> slice_ranges;
+ attrs["slice"] = {/*required=*/true, AttrTy::kSliceRanges, &slice_ranges};
+ if (!ParseOperands(&operands, /*expected_size=*/1) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction = builder->AddInstruction(HloInstruction::CreateSlice(
+ shape, operands[0], slice_ranges->starts, slice_ranges->limits,
+ slice_ranges->strides));
+ break;
+ }
+ case HloOpcode::kDynamicSlice: {
+ optional<std::vector<int64>> dynamic_slice_sizes;
+ attrs["dynamic_slice_sizes"] = {
+ /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes};
+ if (!ParseOperands(&operands, /*expected_size=*/2) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction = builder->AddInstruction(HloInstruction::CreateDynamicSlice(
+ shape, /*operand=*/operands[0], /*start_indices=*/operands[1],
+ *dynamic_slice_sizes));
+ break;
+ }
+ case HloOpcode::kDynamicUpdateSlice: {
+ if (!ParseOperands(&operands, /*expected_size=*/3) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction =
+ builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+ shape, /*operand=*/operands[0], /*update=*/operands[1],
+ /*start_indices=*/operands[2]));
+ break;
+ }
+ case HloOpcode::kTranspose: {
+ optional<std::vector<int64>> dimensions;
+ attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
+ &dimensions};
+ if (!ParseOperands(&operands, /*expected_size=*/1) ||
+ !ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
- HloInstruction::CreateCall(shape, operands, to_apply));
+ HloInstruction::CreateTranspose(shape, operands[0], *dimensions));
+ break;
+ }
+ case HloOpcode::kBatchNormTraining: {
+ optional<float> epsilon;
+ attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
+ optional<int64> feature_index;
+ attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
+ &feature_index};
+ if (!ParseOperands(&operands, /*expected_size=*/3) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction =
+ builder->AddInstruction(HloInstruction::CreateBatchNormTraining(
+ shape, /*operand=*/operands[0], /*scale=*/operands[1],
+ /*offset=*/operands[2], *epsilon, *feature_index));
+ break;
+ }
+ case HloOpcode::kBatchNormInference: {
+ optional<float> epsilon;
+ attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
+ optional<int64> feature_index;
+ attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
+ &feature_index};
+ if (!ParseOperands(&operands, /*expected_size=*/5) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction =
+ builder->AddInstruction(HloInstruction::CreateBatchNormInference(
+ shape, /*operand=*/operands[0], /*scale=*/operands[1],
+ /*offset=*/operands[2], /*mean=*/operands[3],
+ /*variance=*/operands[4], *epsilon, *feature_index));
+ break;
+ }
+ case HloOpcode::kBatchNormGrad: {
+ optional<float> epsilon;
+ attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
+ optional<int64> feature_index;
+ attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
+ &feature_index};
+ if (!ParseOperands(&operands, /*expected_size=*/5) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction = builder->AddInstruction(HloInstruction::CreateBatchNormGrad(
+ shape, /*operand=*/operands[0], /*scale=*/operands[1],
+ /*mean=*/operands[2], /*variance=*/operands[3],
+ /*grad_output=*/operands[4], *epsilon, *feature_index));
+ break;
+ }
+ case HloOpcode::kPad: {
+ optional<PaddingConfig> padding;
+ attrs["padding"] = {/*required=*/true, AttrTy::kPaddingConfig, &padding};
+ if (!ParseOperands(&operands, /*expected_size=*/2) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction = builder->AddInstruction(HloInstruction::CreatePad(
+ shape, operands[0], /*padding_value=*/operands[1], *padding));
break;
}
- case HloOpcode::kBroadcast:
case HloOpcode::kCustomCall:
- case HloOpcode::kConcatenate:
case HloOpcode::kReducePrecision:
- case HloOpcode::kConvolution:
- case HloOpcode::kMap:
- case HloOpcode::kPad:
- case HloOpcode::kReduce:
- case HloOpcode::kReduceWindow:
- case HloOpcode::kSelectAndScatter:
- case HloOpcode::kReverse:
case HloOpcode::kRng:
- case HloOpcode::kSlice:
- case HloOpcode::kDynamicSlice:
- case HloOpcode::kDynamicUpdateSlice:
- case HloOpcode::kTranspose:
case HloOpcode::kFusion:
- case HloOpcode::kBatchNormTraining:
- case HloOpcode::kBatchNormInference:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
- case HloOpcode::kBatchNormGrad:
case HloOpcode::kTrace:
return TokenError(StrCat("parsing not yet implemented for op: ",
HloOpcodeString(opcode)));
}
- bool has_sharding = false;
- bool has_control = false;
- while (EatIfPresent(TokKind::kComma)) {
- string attribute_name;
- if (!ParseAttributeName(&attribute_name)) {
- return TokenError("expects ', sharding=' or ', control-predecessors='");
+ // Add common attrs (sharding, control predecessors) to the instruction, if
+ // they were seen.
+ if (sharding) {
+ instruction->set_sharding(
+ HloSharding::FromProto(sharding.value()).ValueOrDie());
+ }
+ if (predecessors) {
+ for (auto* pre : *predecessors) {
+ Status status = pre->AddControlDependencyTo(instruction);
+ if (!status.ok()) {
+ return TokenError(StrCat("error adding control dependency for: ", name,
+ " status: ", status.ToString()));
+ }
}
+ }
+ if (metadata) {
+ instruction->set_metadata(*metadata);
+ }
+ return AddInstruction(name, instruction);
+}
- if (attribute_name == "sharding") {
- // Parse "sharding=".
- if (has_sharding) {
- return TokenError("expects at most 1 'sharding='");
- }
- has_sharding = true;
- if (!ParseSharding(instruction)) {
- return false;
- }
- } else if (attribute_name == "control-predecessors") {
- // Parse "control-predecessors"
- if (has_control) {
- return TokenError("expects at most 1 'control-predecessors='");
- }
- has_control = true;
- if (!ParseControlPredecessors(instruction)) {
+// ::= '{' (single_sharding | tuple_sharding) '}'
+//
+// tuple_sharding ::= single_sharding* (',' single_sharding)*
+bool HloParser::ParseSharding(OpSharding* sharding) {
+ // A single sharding starts with '{' and is not followed by '{'.
+ // A tuple sharding starts with '{' and is followed by '{', or is '{''}' for
+ // an empty tuple.
+ if (!ParseToken(TokKind::kLbrace,
+ "expected '{' to start sharding attribute")) {
+ return false;
+ }
+
+ if (lexer_.GetKind() != TokKind::kLbrace &&
+ lexer_.GetKind() != TokKind::kRbrace) {
+ return ParseSingleSharding(sharding, /*lbrace_pre_lexed=*/true);
+ }
+
+ // Tuple sharding.
+ // Allow empty tuple shardings.
+ if (lexer_.GetKind() != TokKind::kRbrace) {
+ do {
+ if (!ParseSingleSharding(sharding->add_tuple_shardings(),
+ /*lbrace_pre_lexed=*/false)) {
return false;
}
- } else {
- return TokenError(StrCat("unexpected attribute: ", attribute_name));
- }
+ } while (EatIfPresent(TokKind::kComma));
}
+ sharding->set_type(OpSharding::Type::OpSharding_Type_TUPLE);
- return AddInstruction(name, instruction);
+ return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute");
}
-// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? ('devices=' ('['
-// dims ']')* device_list)? '}' dims ::= int_list device_list ::= int_list
-bool HloParser::ParseSharding(HloInstruction* instruction) {
- if (!ParseToken(TokKind::kLbrace,
+// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape?
+// ('devices=' ('[' dims ']')* device_list)? '}'
+// dims ::= int_list device_list ::= int_list
+bool HloParser::ParseSingleSharding(OpSharding* sharding,
+ bool lbrace_pre_lexed) {
+ if (!lbrace_pre_lexed &&
+ !ParseToken(TokKind::kLbrace,
"expected '{' to start sharding attribute")) {
return false;
}
@@ -545,7 +901,6 @@ bool HloParser::ParseSharding(HloInstruction* instruction) {
}
}
- OpSharding sharding;
if (replicated) {
if (!devices.empty()) {
return TokenError(
@@ -555,7 +910,7 @@ bool HloParser::ParseSharding(HloInstruction* instruction) {
return TokenError(
"replicated shardings should not have any tile shape set");
}
- sharding.set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
+ sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
} else if (maximal) {
if (devices.size() != 1) {
return TokenError(
@@ -564,8 +919,8 @@ bool HloParser::ParseSharding(HloInstruction* instruction) {
if (!ShapeUtil::Equal(tile_shape, Shape())) {
return TokenError("maximal shardings should not have any tile shape set");
}
- sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
- sharding.add_tile_assignment_devices(devices[0]);
+ sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
+ sharding->add_tile_assignment_devices(devices[0]);
} else {
if (devices.size() <= 1) {
return TokenError(
@@ -579,47 +934,43 @@ bool HloParser::ParseSharding(HloInstruction* instruction) {
"non-maximal shardings must have a tile assignment list including "
"dimensions");
}
- sharding.set_type(OpSharding::Type::OpSharding_Type_OTHER);
- *sharding.mutable_tile_shape() = tile_shape;
+ sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER);
+ *sharding->mutable_tile_shape() = tile_shape;
for (int64 dim : tile_assignment_dimensions) {
- sharding.add_tile_assignment_dimensions(dim);
+ sharding->add_tile_assignment_dimensions(dim);
}
for (int64 device : devices) {
- sharding.add_tile_assignment_devices(device);
+ sharding->add_tile_assignment_devices(device);
}
}
- instruction->set_sharding(HloSharding::FromProto(sharding).ValueOrDie());
lexer_.Lex();
return true;
}
// '{' name+ '}'
-bool HloParser::ParseControlPredecessors(HloInstruction* instruction) {
+bool HloParser::ParseInstructionNames(
+ std::vector<HloInstruction*>* instructions) {
if (!ParseToken(TokKind::kLbrace,
- "expects '{' at the beginning of control predecessors")) {
+ "expects '{' at the beginning of instruction name list")) {
return false;
}
do {
string name;
if (!ParseName(&name)) {
- return TokenError("expects a control predecessor");
+ return TokenError("expects a instruction name");
}
- HloInstruction* pre =
+ HloInstruction* instr =
tensorflow::gtl::FindPtrOrNull(instruction_pool_, name);
- if (!pre) {
+ if (!instr) {
return TokenError(
- StrCat("control predecessor ", name, " is not defined: "));
- }
- Status status = pre->AddControlDependencyTo(instruction);
- if (!status.ok()) {
- return TokenError(StrCat("error adding control dependency for: ", name,
- " status: ", status.ToString()));
+ Printf("instruction '%s' is not defined", name.c_str()));
}
+ instructions->push_back(instr);
} while (EatIfPresent(TokKind::kComma));
return ParseToken(TokKind::kRbrace,
- "expects '}' at the end of control predecessors");
+ "expects '}' at the end of control instructions");
}
bool HloParser::SetValueInLiteral(int64 value, int64 linear_index,
@@ -957,28 +1308,199 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands,
return true;
}
-// extra_attribute ::= ',' attribute_name value
-template <typename T>
-bool HloParser::ParseExtraAttribute(T* value,
- const string& expected_attribute) {
- if (!ParseToken(TokKind::kComma,
- "expects ',' in front of an extra attribute")) {
+// sub_attributes ::= '{' (','? attribute)* '}'
+bool HloParser::ParseSubAttributes(
+ const std::unordered_map<string, AttrConfig>& attrs) {
+ if (!ParseToken(TokKind::kLbrace, "expects '{' to start sub attributes")) {
return false;
}
- string attribute_name;
- if (!ParseAttributeName(&attribute_name) &&
- attribute_name != expected_attribute) {
- return TokenError(StrCat("expects attribute name: ", expected_attribute));
+ std::unordered_set<string> seen_attrs;
+ if (lexer_.GetKind() == TokKind::kRbrace) {
+ // empty
+ } else {
+ do {
+ EatIfPresent(TokKind::kComma);
+ if (!ParseAttributeHelper(attrs, &seen_attrs)) {
+ return false;
+ }
+ } while (lexer_.GetKind() != TokKind::kRbrace);
+ }
+ // Check that all required attrs were seen.
+ for (const auto& attr_it : attrs) {
+ if (attr_it.second.required &&
+ seen_attrs.find(attr_it.first) == seen_attrs.end()) {
+ return TokenError(Printf("sub-attribute %s is expected but not seen",
+ attr_it.first.c_str()));
+ }
}
- if (!ParseAttributeValue(value)) {
- return TokenError(
- StrCat("expects value for attribute: ", expected_attribute));
+ return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes");
+}
+
+// attributes ::= (',' attribute)*
+bool HloParser::ParseAttributes(
+ const std::unordered_map<string, AttrConfig>& attrs) {
+ std::unordered_set<string> seen_attrs;
+ while (EatIfPresent(TokKind::kComma)) {
+ if (!ParseAttributeHelper(attrs, &seen_attrs)) {
+ return false;
+ }
+ }
+ // Check that all required attrs were seen.
+ for (const auto& attr_it : attrs) {
+ if (attr_it.second.required &&
+ seen_attrs.find(attr_it.first) == seen_attrs.end()) {
+ return TokenError(Printf("attribute %s is expected but not seen",
+ attr_it.first.c_str()));
+ }
}
return true;
}
-template <>
-bool HloParser::ParseAttributeValue<HloComputation*>(HloComputation** value) {
+bool HloParser::ParseAttributeHelper(
+ const std::unordered_map<string, AttrConfig>& attrs,
+ std::unordered_set<string>* seen_attrs) {
+ string name;
+ if (!ParseAttributeName(&name)) {
+ return TokenError("error parsing attributes");
+ }
+ VLOG(1) << "Parsing attribute " << name;
+ if (!seen_attrs->insert(name).second) {
+ return TokenError(Printf("attribute %s already exists", name.c_str()));
+ }
+ auto attr_it = attrs.find(name);
+ if (attr_it == attrs.end()) {
+ return TokenError(Printf("unexpected attribute %s", name.c_str()));
+ }
+ AttrTy attr_type = attr_it->second.attr_type;
+ void* attr_out_ptr = attr_it->second.result;
+ bool success = [&] {
+ switch (attr_type) {
+ case AttrTy::kInt64: {
+ int64 result;
+ if (!ParseInt64(&result)) {
+ return false;
+ }
+ static_cast<optional<int64>*>(attr_out_ptr)->emplace(result);
+ return true;
+ }
+ case AttrTy::kInt32: {
+ int64 result;
+ if (!ParseInt64(&result)) {
+ return false;
+ }
+ if (result != static_cast<int32>(result)) {
+ return TokenError("value out of range for int32");
+ }
+ static_cast<optional<int32>*>(attr_out_ptr)
+ ->emplace(static_cast<int32>(result));
+ return true;
+ }
+ case AttrTy::kFloat: {
+ double result;
+ if (!ParseDouble(&result)) {
+ return false;
+ }
+ if (result > std::numeric_limits<float>::max() ||
+ result < std::numeric_limits<float>::lowest()) {
+ return TokenError("value out of range for float");
+ }
+ static_cast<optional<float>*>(attr_out_ptr)
+ ->emplace(static_cast<float>(result));
+ return true;
+ }
+ case AttrTy::kHloComputation: {
+ HloComputation* result;
+ if (!ParseComputationName(&result)) {
+ return false;
+ }
+ static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result);
+ return true;
+ }
+ case AttrTy::kWindow: {
+ Window result;
+ if (!ParseWindow(&result)) {
+ return false;
+ }
+ static_cast<optional<Window>*>(attr_out_ptr)->emplace(result);
+ return true;
+ }
+ case AttrTy::kConvolutionDimensionNumbers: {
+ ConvolutionDimensionNumbers result;
+ if (!ParseConvolutionDimensionNumbers(&result)) {
+ return false;
+ }
+ static_cast<optional<ConvolutionDimensionNumbers>*>(attr_out_ptr)
+ ->emplace(result);
+ return true;
+ }
+ case AttrTy::kSharding: {
+ OpSharding sharding;
+ if (!ParseSharding(&sharding)) {
+ return false;
+ }
+ static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
+ return true;
+ }
+ case AttrTy::kInstructionList: {
+ std::vector<HloInstruction*> result;
+ if (!ParseInstructionNames(&result)) {
+ return false;
+ }
+ static_cast<optional<std::vector<HloInstruction*>>*>(attr_out_ptr)
+ ->emplace(result);
+ return true;
+ }
+ case AttrTy::kBracedInt64List: {
+ std::vector<int64> result;
+ if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
+ &result)) {
+ return false;
+ }
+ static_cast<optional<std::vector<int64>>*>(attr_out_ptr)
+ ->emplace(result);
+ return true;
+ }
+ case AttrTy::kSliceRanges: {
+ SliceRanges result;
+ if (!ParseSliceRanges(&result)) {
+ return false;
+ }
+ static_cast<optional<SliceRanges>*>(attr_out_ptr)->emplace(result);
+ return true;
+ }
+ case AttrTy::kPaddingConfig: {
+ PaddingConfig result;
+ if (!ParsePaddingConfig(&result)) {
+ return false;
+ }
+ static_cast<optional<PaddingConfig>*>(attr_out_ptr)->emplace(result);
+ return true;
+ }
+ case AttrTy::kString: {
+ string result;
+ if (!ParseString(&result)) {
+ return false;
+ }
+ static_cast<optional<string>*>(attr_out_ptr)->emplace(result);
+ return true;
+ }
+ case AttrTy::kMetadata: {
+ OpMetadata result;
+ if (!ParseMetadata(&result)) {
+ return false;
+ }
+ static_cast<optional<OpMetadata>*>(attr_out_ptr)->emplace(result);
+ return true;
+ }
+ }
+ }();
+ if (!success) {
+ return TokenError(Printf("error parsing attribute %s", name.c_str()));
+ }
+ return true;
+}
+
+bool HloParser::ParseComputationName(HloComputation** value) {
string name;
if (!ParseName(&name)) {
return TokenError("expects computation name");
@@ -990,9 +1512,269 @@ bool HloParser::ParseAttributeValue<HloComputation*>(HloComputation** value) {
return true;
}
-template <>
-bool HloParser::ParseAttributeValue<int64>(int64* value) {
- return ParseInt64(value);
+// ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}'
+// The subattributes can appear in any order. 'size=' is required, others are
+// optional.
+bool HloParser::ParseWindow(Window* window) {
+ if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) {
+ return false;
+ }
+
+ std::vector<int64> size;
+ std::vector<int64> stride;
+ std::vector<std::vector<int64>> pad;
+ std::vector<int64> lhs_dilate;
+ std::vector<int64> rhs_dilate;
+ while (lexer_.GetKind() != TokKind::kRbrace) {
+ string field_name;
+ if (!ParseAttributeName(&field_name)) {
+ return TokenError("expects sub-attributes in window");
+ }
+ bool ok = [&] {
+ if (field_name == "size") {
+ return ParseDxD("size", &size);
+ }
+ if (field_name == "stride") {
+ return ParseDxD("stride", &stride);
+ }
+ if (field_name == "lhs_dilate") {
+ return ParseDxD("lhs_dilate", &lhs_dilate);
+ }
+ if (field_name == "rhs_dilate") {
+ return ParseDxD("rls_dilate", &rhs_dilate);
+ }
+ if (field_name == "pad") {
+ return ParseWindowPad(&pad);
+ }
+ return TokenError(StrCat("unexpected attribute name: ", field_name));
+ }();
+ if (!ok) {
+ return false;
+ }
+ }
+
+ if (size.empty()) {
+ return TokenError(
+ "sub-attribute 'size=' is required in the window attribute");
+ }
+ if (!stride.empty() && stride.size() != size.size()) {
+ return TokenError("expects 'stride=' has the same size as 'size='");
+ }
+ if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) {
+ return TokenError("expects 'lhs_dilate=' has the same size as 'size='");
+ }
+ if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) {
+ return TokenError("expects 'rhs_dilate=' has the same size as 'size='");
+ }
+ if (!pad.empty() && pad.size() != size.size()) {
+ return TokenError("expects 'pad=' has the same size as 'size='");
+ }
+
+ for (int i = 0; i < size.size(); i++) {
+ window->add_dimensions()->set_size(size[i]);
+ if (!pad.empty()) {
+ window->mutable_dimensions(i)->set_padding_low(pad[i][0]);
+ window->mutable_dimensions(i)->set_padding_high(pad[i][1]);
+ }
+ // If some field is not present, it has the default value.
+ window->mutable_dimensions(i)->set_stride(stride.empty() ? 1 : stride[i]);
+ window->mutable_dimensions(i)->set_base_dilation(
+ lhs_dilate.empty() ? 1 : lhs_dilate[i]);
+ window->mutable_dimensions(i)->set_window_dilation(
+ rhs_dilate.empty() ? 1 : rhs_dilate[i]);
+ }
+ return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute");
+}
+
+// This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString.
+// The string looks like "dim_labels=0bf_0io->0bf".
+bool HloParser::ParseConvolutionDimensionNumbers(
+ ConvolutionDimensionNumbers* dnums) {
+ if (lexer_.GetKind() != TokKind::kDimLabels) {
+ return TokenError("expects dim labels pattern, e.g., 'bf0_0io->0bf'");
+ }
+ string str = lexer_.GetStrVal();
+
+ // The str is expected to have 3 items, lhs, rhs, out, and it must looks like
+ // lhs_rhs->out, that is, the first separator is "_" and the second is "->".
+ // So we replace the "->" with "_" and then split on "_".
+ str = tensorflow::str_util::StringReplace(str, /*oldsub=*/"->",
+ /*newsub=*/"_",
+ /*replace_all=*/false);
+ std::vector<string> lhs_rhs_out = Split(str, "_");
+ if (lhs_rhs_out.size() != 3) {
+ LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
+ << str;
+ }
+
+ const int64 rank = lhs_rhs_out[0].length();
+ if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) {
+ return TokenError(
+ "convolution lhs, rhs, and output must have the same rank");
+ }
+ if (rank < 3) {
+ return TokenError("convolution rank must >=3");
+ }
+
+ auto is_unique = [](string str) -> bool {
+ std::sort(str.begin(), str.end());
+ return std::unique(str.begin(), str.end()) == str.end();
+ };
+
+ // lhs
+ {
+ const string& lhs = lhs_rhs_out[0];
+ if (!is_unique(lhs)) {
+ return TokenError(
+ StrCat("expects unique lhs dimension numbers, but sees ", lhs));
+ }
+ for (int i = 0; i < rank - 2; i++) {
+ dnums->add_spatial_dimensions(-1);
+ }
+ for (int i = 0; i < rank; i++) {
+ char c = lhs[i];
+ if (c == 'b') {
+ dnums->set_input_batch_dimension(i);
+ } else if (c == 'f') {
+ dnums->set_input_feature_dimension(i);
+ } else if (c < '0' + rank && c >= '0') {
+ dnums->set_spatial_dimensions(c - '0', i);
+ } else {
+ return TokenError(
+ Printf("expects [0-%lldbf] in lhs dimension numbers", rank - 1));
+ }
+ }
+ }
+ // rhs
+ {
+ const string& rhs = lhs_rhs_out[1];
+ if (!is_unique(rhs)) {
+ return TokenError(
+ StrCat("expects unique rhs dimension numbers, but sees ", rhs));
+ }
+ for (int i = 0; i < rank - 2; i++) {
+ dnums->add_kernel_spatial_dimensions(-1);
+ }
+ for (int i = 0; i < rank; i++) {
+ char c = rhs[i];
+ if (c == 'i') {
+ dnums->set_kernel_input_feature_dimension(i);
+ } else if (c == 'o') {
+ dnums->set_kernel_output_feature_dimension(i);
+ } else if (c < '0' + rank && c >= '0') {
+ dnums->set_kernel_spatial_dimensions(c - '0', i);
+ } else {
+ return TokenError(
+ Printf("expects [0-%lldio] in rhs dimension numbers", rank - 1));
+ }
+ }
+ }
+ // output
+ {
+ const string& out = lhs_rhs_out[2];
+ if (!is_unique(out)) {
+ return TokenError(
+ StrCat("expects unique output dimension numbers, but sees ", out));
+ }
+ for (int i = 0; i < rank; i++) {
+ char c = out[i];
+ if (c == 'b') {
+ dnums->set_output_batch_dimension(i);
+ } else if (c == 'f') {
+ dnums->set_output_feature_dimension(i);
+ } else if (c < '0' + rank && c >= '0') {
+ if (dnums->spatial_dimensions(c - '0') != i) {
+ return TokenError(
+ "output spatial dimensions should be the same as input spatial "
+ "dimensions");
+ }
+ } else {
+ return TokenError(
+ Printf("expects [0-%lldbf] in output dimension numbers", rank - 1));
+ }
+ }
+ }
+
+ lexer_.Lex();
+ return true;
+}
+
+// ::= '{' ranges '}'
+// ::= /*empty*/
+// ::= range (',' range)*
+// range ::= '[' start ':' limit (':' stride)? ']'
+//
+// The slice ranges are printed as:
+//
+// {[dim0_start:dim0_limit:dim0stride], [dim1_start:dim1_limit], ...}
+//
+// This function extracts the starts, limits, and strides as 3 vectors to the
+// result. If stride is not present, stride is 1. For example, if the slice
+// ranges is printed as:
+//
+// {[2:3:4], [5:6:7], [8:9]}
+//
+// The the parsed result will be:
+//
+// {/*starts=*/{2, 5, 8}, /*limits=*/{3, 6, 9}, /*strides=*/{4, 7, 1}}
+//
+bool HloParser::ParseSliceRanges(SliceRanges* result) {
+ if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) {
+ return false;
+ }
+ std::vector<std::vector<int64>> ranges;
+ if (lexer_.GetKind() == TokKind::kRbrace) {
+ // empty
+ return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
+ }
+ do {
+ ranges.emplace_back();
+ if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kColon,
+ &ranges.back())) {
+ return false;
+ }
+ } while (EatIfPresent(TokKind::kComma));
+
+ for (const auto& range : ranges) {
+ if (range.size() != 2 && range.size() != 3) {
+ return TokenError(Printf(
+ "expects [start:limit:step] or [start:limit], but sees %ld elements.",
+ range.size()));
+ }
+ }
+
+ for (const auto& range : ranges) {
+ result->starts.push_back(range[0]);
+ result->limits.push_back(range[1]);
+ result->strides.push_back(range.size() == 3 ? range[2] : 1);
+ }
+ return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
+}
+
+// int64list ::= start int64_elements end
+// int64_elements
+// ::= /*empty*/
+// ::= int64_val (delim int64_val)*
+bool HloParser::ParseInt64List(const TokKind start, const TokKind end,
+ const TokKind delim,
+ std::vector<int64>* result) {
+ if (!ParseToken(start, StrCat("expects an int64 list starting with ",
+ TokKindToString(start)))) {
+ return false;
+ }
+ if (lexer_.GetKind() == end) {
+ // empty
+ } else {
+ do {
+ int64 i;
+ if (!ParseInt64(&i)) {
+ return false;
+ }
+ result->push_back(i);
+ } while (EatIfPresent(delim));
+ }
+ return ParseToken(
+ end, StrCat("expects an int64 list to end with ", TokKindToString(end)));
}
// param_list ::= '(' param_list1 ')'
@@ -1070,6 +1852,121 @@ bool HloParser::ParseAttributeName(string* result) {
return true;
}
+bool HloParser::ParseString(string* result) {
+ VLOG(1) << "ParseString";
+ if (lexer_.GetKind() != TokKind::kString) {
+ return TokenError("expects string");
+ }
+ *result = lexer_.GetStrVal();
+ lexer_.Lex();
+ return true;
+}
+
+bool HloParser::ParseDxD(const string& name, std::vector<int64>* result) {
+ if (!result->empty()) {
+ return TokenError(
+ Printf("sub-attribute '%s=' already exists", name.c_str()));
+ }
+ // 1D
+ if (lexer_.GetKind() == TokKind::kInt) {
+ int64 number;
+ if (!ParseInt64(&number)) {
+ return TokenError(Printf("expects sub-attribute '%s=i'", name.c_str()));
+ }
+ result->push_back(number);
+ return true;
+ }
+ // 2D or higher.
+ if (lexer_.GetKind() == TokKind::kDxD) {
+ string str = lexer_.GetStrVal();
+ if (!SplitAndParseAsInts(str, 'x', result)) {
+ return TokenError(
+ Printf("expects sub-attribute '%s=ixj...'", name.c_str()));
+ }
+ lexer_.Lex();
+ return true;
+ }
+ return TokenError("expects token type kInt or kDxD");
+}
+
+bool HloParser::ParseWindowPad(std::vector<std::vector<int64>>* pad) {
+ if (!pad->empty()) {
+ return TokenError("sub-attribute 'pad=' already exists");
+ }
+ if (lexer_.GetKind() != TokKind::kPad) {
+ return TokenError("expects window pad pattern, e.g., '0_0x3_3'");
+ }
+ string str = lexer_.GetStrVal();
+ std::vector<string> padding_str = Split(str, 'x');
+ for (int i = 0; i < padding_str.size(); i++) {
+ std::vector<int64> low_high;
+ if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) ||
+ low_high.size() != 2) {
+ return TokenError(
+ "expects padding_low and padding_high separated by '_'");
+ }
+ pad->push_back(low_high);
+ }
+ lexer_.Lex();
+ return true;
+}
+
+// This is the inverse xla::ToString(PaddingConfig). The padding config string
+// looks like "0_0_0x3_3_1". The string is first separated by 'x', each
+// substring represents one PaddingConfigDimension. The substring is 3 (or 2)
+// numbers joined by '_'.
+bool HloParser::ParsePaddingConfig(PaddingConfig* padding) {
+ if (lexer_.GetKind() != TokKind::kPad) {
+ return TokenError("expects padding config, e.g., '0_0_0x3_3_1'");
+ }
+ string str = lexer_.GetStrVal();
+ std::vector<string> padding_str = Split(str, 'x');
+ for (const auto& padding_dim_str : padding_str) {
+ std::vector<int64> padding_dim;
+ if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) ||
+ (padding_dim.size() != 2 && padding_dim.size() != 3)) {
+ return TokenError(
+ "expects padding config pattern like 'low_high_interior' or "
+ "'low_high'");
+ }
+ auto* dim = padding->add_dimensions();
+ dim->set_edge_padding_low(padding_dim[0]);
+ dim->set_edge_padding_high(padding_dim[1]);
+ dim->set_interior_padding(padding_dim.size() == 3 ? padding_dim[2] : 0);
+ }
+ lexer_.Lex();
+ return true;
+}
+
+// '{' metadata_string '}'
+bool HloParser::ParseMetadata(OpMetadata* metadata) {
+ std::unordered_map<string, AttrConfig> attrs;
+ optional<string> op_type;
+ optional<string> op_name;
+ optional<string> source_file;
+ optional<int32> source_line;
+ attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type};
+ attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name};
+ attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file};
+ attrs["source_line"] = {/*required=*/false, AttrTy::kInt32, &source_line};
+ if (!ParseSubAttributes(attrs)) {
+ return false;
+ }
+ if (op_type) {
+ metadata->set_op_type(*op_type);
+ }
+ if (op_name) {
+ metadata->set_op_name(*op_name);
+ }
+ if (source_file) {
+ metadata->set_source_file(*source_file);
+ }
+ if (source_line) {
+ metadata->set_source_line(*source_line);
+ }
+ return true;
+}
+
bool HloParser::ParseOpcode(HloOpcode* result) {
VLOG(1) << "ParseOpcode";
if (lexer_.GetKind() != TokKind::kOpcode) {
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
index 359256f064..bed912d921 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -25,6 +25,7 @@ namespace tools {
namespace {
using tensorflow::StringPiece;
+using tensorflow::strings::StrCat;
struct TestData {
string test_name;
@@ -35,6 +36,10 @@ string TestDataToString(const ::testing::TestParamInfo<TestData>& data) {
return data.param.test_name;
}
+// For each string below, we check that:
+// - we parse it to an HloModule successfully, and
+// - the stringification of the resulting HloModule is equal to our original
+// string.
std::vector<TestData> CreateTestCases() {
// clang-format off
return std::vector<TestData>({
@@ -43,10 +48,11 @@ std::vector<TestData> CreateTestCases() {
"AxpyParam",
R"(HloModule axpy_module:
-ENTRY %axpy.v5 (alpha: f32[2,4], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
- %alpha = f32[2,4]{1,0} parameter(0)
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
%x = f32[2,4]{1,0} parameter(1)
- %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %alpha, f32[2,4]{1,0} %x)
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
%y = f32[2,4]{1,0} parameter(2)
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
@@ -59,7 +65,7 @@ ENTRY %axpy.v5 (alpha: f32[2,4], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
R"(HloModule constant_pred_module:
ENTRY %constant_pred () -> pred[] {
- ROOT %constant = pred[] constant(true)
+ ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68}
}
)"
@@ -77,7 +83,8 @@ ENTRY %constant_s32 () -> s32[] {
},
// f32 constant, but the value is not a decimal
{
-"ConstantF32", R"(HloModule ConstantF32_module:
+"ConstantF32",
+R"(HloModule ConstantF32_module:
ENTRY %ConstantF32.v4 () -> f32[] {
ROOT %constant = f32[] constant(42)
@@ -151,7 +158,7 @@ ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f3
%v1 = f32[4]{0} parameter(0), sharding={maximal device=1}
%v2 = f32[4]{0} parameter(1), sharding={maximal device=1}
%greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2), sharding={replicated}
- ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2)
+ ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={}
}
)"
@@ -181,6 +188,19 @@ ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f
)"
},
+{
+"ShardedTupleCreate",
+R"(HloModule ShardedTupleCreate_module:
+
+ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
+ %v1 = f32[] parameter(0)
+ %v2 = f32[3]{0} parameter(1)
+ %v3 = f32[2,3]{1,0} parameter(2)
+ ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{replicated}, {maximal device=0}, {replicated}}
+}
+
+)"
+},
// int32 result = 0;
// while (result < 5) { result = result + 1; }
{
@@ -212,9 +232,11 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
R"(HloModule TwoSendRecvBothWayRecvFist_module:
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
- %recv = f32[] recv(), channel_id=15, sharding={maximal device=1}
- ROOT %constant = f32[] constant(2.1), sharding={maximal device=0}
- %send = () send(f32[] %constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
+ %recv = (f32[], u32[]) recv(), channel_id=15, sharding={maximal device=1}
+ ROOT %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15, sharding={maximal device=1}
+ %constant = f32[] constant(2.1), sharding={maximal device=0}
+ %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
+ %send-done = () send-done((f32[], u32[]) %send), channel_id=16, sharding={maximal device=0}
}
)"
@@ -248,6 +270,277 @@ ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] {
}
)"
+},
+// reduce window
+{
+"ReduceWindow",
+R"(HloModule R4UnitWindow_module:
+
+%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
+}
+
+ENTRY %R4UnitWindow.v3 (operand: f32[13,12,8,15]) -> f32[13,3,8,15] {
+ %operand = f32[13,12,8,15]{0,3,2,1} parameter(0)
+ %constant = f32[] constant(0)
+ ROOT %reduce-window = f32[13,3,8,15]{0,3,2,1} reduce-window(f32[13,12,8,15]{0,3,2,1} %operand, f32[] %constant), window={size=1x1x7x1 stride=1x4x1x1 pad=0_0x0_0x3_3x0_0}, to_apply=%add_F32.v3
+}
+
+)"
+},
+// convolution
+{
+"Convolution",
+R"(HloModule Convolve1D1Window_0_module:
+
+ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
+ %input = f32[1,2,1]{2,1,0} parameter(0)
+ %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
+ %filter = f32[1,1,1]{2,1,0} parameter(1)
+ ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f
+}
+
+)"
+},
+// reverse(constant)
+{
+"Reverse4D",
+R"(HloModule Reverse4DFloatArrayOnDim01_module:
+
+ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] {
+ %constant = f32[4,3,2,1]{0,1,2,3} constant(f32[4,3,2,1] { { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } })
+ ROOT %reverse = f32[4,3,2,1]{0,1,2,3} reverse(f32[4,3,2,1]{0,1,2,3} %constant), dimensions={0,1}
+}
+
+)"
+},
+// concat
+{
+"Concat",
+R"(HloModule Concat2x3With2x5_module:
+
+ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] {
+ %constant = f32[2,3]{1,0} constant(f32[2,3] { { 0, 1, 2 }, { 1000, 1001, 1002 } })
+ %constant.1 = f32[2,5]{1,0} constant(f32[2,5] { { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } })
+ ROOT %concatenate = f32[2,8]{1,0} concatenate(f32[2,3]{1,0} %constant, f32[2,5]{1,0} %constant.1), dimensions={1}
+}
+
+)"
+},
+// map
+{
+"Map",
+R"(HloModule MapBinaryAdder_module:
+
+%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
+}
+
+ENTRY %MapBinaryAdder.v3 (param0: f32[4], param1: f32[4]) -> f32[4] {
+ %param0 = f32[4]{0} parameter(0)
+ %param1 = f32[4]{0} parameter(1)
+ ROOT %map = f32[4]{0} map(f32[4]{0} %param0, f32[4]{0} %param1), to_apply=%add_F32.v3
+}
+
+)"
+},
+// reduce
+{
+"Reduce",
+R"(HloModule ReduceR3ToR2_module:
+
+%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
+}
+
+ENTRY %ReduceR3ToR2.v3 (input: f32[8,16,256]) -> f32[8,16] {
+ %input = f32[8,16,256]{2,1,0} parameter(0)
+ %constant = f32[] constant(0)
+ ROOT %reduce = f32[8,16]{1,0} reduce(f32[8,16,256]{2,1,0} %input, f32[] %constant), dimensions={2}, to_apply=%add_F32.v3
+}
+
+)"
+},
+// select and scatter
+{
+"SelectAndScatter",
+R"(HloModule R4F32OverlapSmall_module:
+
+%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ ROOT %greater-than-or-equal-to = pred[] greater-than-or-equal-to(f32[] %lhs, f32[] %rhs)
+}
+
+%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
+ %lhs.1 = f32[] parameter(0)
+ %rhs.1 = f32[] parameter(1)
+ ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1)
+}
+
+ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] {
+ %constant = f32[4,5,1,1]{3,2,1,0} constant(f32[4,5,1,1] { { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } })
+ %constant.1 = f32[2,2,1,1]{3,2,1,0} constant(f32[2,2,1,1] { { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } })
+ %constant.2 = f32[] constant(0)
+ ROOT %select-and-scatter = f32[4,5,1,1]{3,2,1,0} select-and-scatter(f32[4,5,1,1]{3,2,1,0} %constant, f32[2,2,1,1]{3,2,1,0} %constant.1, f32[] %constant.2), window={size=2x3x1x1 stride=2x2x1x1}, select=%ge_F32.v3, scatter=%add_F32.v3
+}
+
+)"
+},
+// slice
+{
+"Slice",
+R"(HloModule slice_module:
+
+ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
+ %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0)
+ ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3:1], [0:3:1], [0:4:2], [0:4:1]}
+}
+
+)"
+},
+// slice, no stride
+{
+"SliceNoStride",
+R"(HloModule Slice3x3x3_To_1x3x3_F32_module:
+
+ENTRY %Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] {
+ %constant = f32[3,3,3]{2,1,0} constant(f32[3,3,3] { { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } })
+ ROOT %slice = f32[1,3,3]{2,1,0} slice(f32[3,3,3]{2,1,0} %constant), slice={[0:1], [0:3], [0:3]}
+}
+
+)"
+},
+// slice R0
+{
+"SliceR0",
+R"(HloModule SliceR0_module:
+
+ENTRY %SliceR0.v2 () -> s32[] {
+ %constant = s32[] constant(1)
+ ROOT %slice = s32[] slice(s32[] %constant), slice={}
+}
+
+)"
+},
+// transpose
+{
+"Transpose",
+R"(HloModule Transpose_module:
+
+ENTRY %Transpose.v2 () -> s32[1,2,3] {
+ %constant = s32[1,2,3]{2,1,0} constant(s32[1,2,3] { { { 1, 2, 3 }, { 4, 5, 6 } } })
+ ROOT %transpose = s32[1,2,3]{2,1,0} transpose(s32[1,2,3]{2,1,0} %constant), dimensions={0,1,2}
+}
+
+)"
+},
+// Dynamic slice
+{
+"DynamicSlice",
+R"(HloModule DynamicSlice_module:
+
+ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[1]) -> s32[2,2,258] {
+ %original_parameter = s32[2,2,258]{2,1,0} parameter(0)
+ %constant = s32[1]{0} constant({0})
+ %start_index = s32[1]{0} parameter(1)
+ %concatenate = s32[3]{0} concatenate(s32[1]{0} %constant, s32[1]{0} %constant, s32[1]{0} %start_index), dimensions={0}
+ ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[3]{0} %concatenate), dynamic_slice_sizes={2,2,258}
+}
+
+)"
+},
+// Dynamic update slice
+{
+"DynamicUpdateSlice",
+R"(HloModule DynamicUpdateSlice_module:
+
+ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_indices: s32[4]) -> s32[1,1,25,1] {
+ %input = s32[1,1,25,1]{3,2,1,0} parameter(0)
+ %update = s32[1,1,2,1]{3,2,1,0} parameter(1)
+ %start_indices = s32[4]{0} parameter(2)
+ ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[4]{0} %start_indices)
+}
+
+)"
+},
+// batch norm training
+{
+"BatchNormTraining",
+R"(HloModule BasicTraining_module:
+
+ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) {
+ %constant = f32[2,2,1,2]{3,2,1,0} constant(f32[2,2,1,2] { { /*i0=0*/ { /*i1=0*/ {1, 2} }, { /*i1=1*/ {3, 4} } }, { /*i0=1*/ { /*i1=0*/ {5, 6} }, { /*i1=1*/ {7, 8} } } })
+ %constant.1 = f32[2]{0} constant({2, 3})
+ %constant.2 = f32[2]{0} constant({1, 2})
+ ROOT %batch-norm-training = (f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-training(f32[2,2,1,2]{3,2,1,0} %constant, f32[2]{0} %constant.1, f32[2]{0} %constant.2), epsilon=0.001, feature_index=3
+}
+
+)"
+},
+// batch norm inference
+{
+"BatchNormInference",
+R"(HloModule BatchNormInference_module:
+
+ENTRY %BatchNormInference.v6 (input: f32[2,2,2,2], offset: f32[2], scale: f32[2], mean: f32[2], variance: f32[2]) -> f32[2,2,2,2] {
+ %input = f32[2,2,2,2]{3,2,1,0} parameter(0)
+ %offset = f32[2]{0} parameter(1)
+ %scale = f32[2]{0} parameter(2)
+ %mean = f32[2]{0} parameter(3)
+ %variance = f32[2]{0} parameter(4)
+ ROOT %batch-norm-inference = f32[2,2,2,2]{3,2,1,0} batch-norm-inference(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %offset, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance), epsilon=0.001, feature_index=0
+}
+
+)"
+},
+// batch norm grad
+{
+"BatchNormGrad",
+R"(HloModule BatchNormGrad_module:
+
+ENTRY %BatchNormGrad.v4 (input: f32[2,2,2,2], scale: f32[2], mean: f32[2], variance: f32[2], grad_output: f32[2,2,2,2]) -> (f32[2,2,2,2], f32[2], f32[2]) {
+ %input = f32[2,2,2,2]{3,2,1,0} parameter(0)
+ %scale = f32[2]{0} parameter(1)
+ %mean = f32[2]{0} parameter(2)
+ %variance = f32[2]{0} parameter(3)
+ %grad_output = f32[2,2,2,2]{3,2,1,0} parameter(4)
+ ROOT %batch-norm-grad = (f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-grad(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[2,2,2,2]{3,2,1,0} %grad_output), epsilon=0.001, feature_index=0
+}
+
+)"
+},
+// pad
+{
+"Pad",
+R"(HloModule Pad1DS3Array_module:
+
+ENTRY %Pad1DS3Array.v3 () -> f32[8] {
+ %constant = f32[3]{0} constant({1, 2, 3})
+ %constant.1 = f32[] constant(0.1)
+ ROOT %pad = f32[8]{0} pad(f32[3]{0} %constant, f32[] %constant.1), padding=3_1
+}
+
+)"
+},
+// pad has interior
+{
+"PadHasInterior",
+R"(HloModule PadHasInterior_module:
+
+ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] {
+ %input = f32[1,25,7,7]{3,2,1,0} parameter(0)
+ %constant = f32[] constant(-5.123)
+ ROOT %pad = f32[1,25,17,11]{3,2,1,0} pad(f32[1,25,7,7]{3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_0_0x2_2_1x2_2_0
+}
+
+)"
}
});
// clang-format on
@@ -261,7 +554,10 @@ class HloParserTest : public ::testing::Test,
<< "'" << s << "' does not contain '" << expected << "'";
}
- void ExpectSuccess() {
+ // Expects "ToString(Parse(string)) == string", that is, parses the string,
+ // asserts that it succeeded, stringifies the parsed module, and checks that
+ // the it equals the original string.
+ void ExpectEqual() {
const string& original = GetParam().module_string;
auto result = Parse(original);
TF_EXPECT_OK(result.status());
@@ -270,7 +566,7 @@ class HloParserTest : public ::testing::Test,
}
};
-TEST_P(HloParserTest, Run) { ExpectSuccess(); }
+TEST_P(HloParserTest, Run) { ExpectEqual(); }
INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest,
::testing::ValuesIn(CreateTestCases()),
@@ -427,6 +723,136 @@ ENTRY %ConstantWithExp.v4 () -> f32[] {
// printed as "300".
}
+TEST_F(HloParserTest, AttibutesAnyOrder) {
+ const string original = R"(HloModule any_order_module:
+
+ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
+ %input = f32[1,2,1]{2,1,0} parameter(0)
+ %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
+ %filter = f32[1,1,1]{2,1,0} parameter(1)
+ ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, dim_labels=b0f_0io->b0f, window={pad=1_1 size=2}
+}
+
+)";
+ TF_EXPECT_OK(Parse(original).status());
+}
+
+TEST_F(HloParserTest, InvalidDimLabels) {
+ string prefix = R"(HloModule invalid_dim_labels_module:
+
+ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
+ %input = f32[1,2,1]{2,1,0} parameter(0)
+ %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
+ %filter = f32[1,1,1]{2,1,0} parameter(1)
+ ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1} )";
+ string suffix = R"(
+}
+
+)";
+
+ ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=00_01_10", suffix))
+ .status()
+ .error_message(),
+ "expects dim labels pattern");
+
+ ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=010_1100->010", suffix))
+ .status()
+ .error_message(),
+ "must have the same rank");
+
+ ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=0bf_io0->b0f", suffix))
+ .status()
+ .error_message(),
+ "output spatial dimensions should be the same as input "
+ "spatial dimensions");
+}
+
+TEST_F(HloParserTest, UnexpectedAttribute) {
+ const string original = R"(HloModule unexpected_attr_module:
+
+ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
+ %recv = (f32[], u32[]) recv(), channel_id=15
+ %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15
+ ROOT %constant = f32[] constant(2.1)
+ %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, calls=%recv
+ %send-done = () send-done((f32[], u32[]) %send), channel_id=16
+}
+
+)";
+ ExpectHasSubstr(Parse(original).status().error_message(),
+ "unexpected attribute calls");
+}
+
+TEST_F(HloParserTest, MissingAttribute) {
+ const string original = R"(HloModule missing_attr_module:
+
+ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
+ %recv = (f32[], u32[]) recv(), channel_id=15
+ %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15
+ ROOT %constant = f32[] constant(-2.1)
+ %send = (f32[], u32[]) send(f32[] %constant)
+ %send-done = () send-done((f32[], u32[]) %send), channel_id=16
+}
+
+)";
+ ExpectHasSubstr(Parse(original).status().error_message(),
+ "attribute channel_id is expected but not seen");
+}
+
+TEST_F(HloParserTest, PredecessorUndefined) {
+ const string original = R"(HloModule pre_not_found_module:
+
+ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
+ %recv = (f32[], u32[]) recv(), channel_id=15
+ %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15
+ ROOT %constant = f32[] constant(2.1)
+ %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, control-predecessors={%done}
+ %send-done = () send-done((f32[], u32[]) %send), channel_id=16
+}
+
+)";
+ ExpectHasSubstr(Parse(original).status().error_message(),
+ "'done' is not defined");
+}
+
+TEST_F(HloParserTest, SliceAllowOmitStride1) {
+ const string original = R"(HloModule slice_module:
+
+ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
+ %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0)
+ ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3], [0:3], [0:4:2], [0:4]}
+}
+
+)";
+ TF_EXPECT_OK(Parse(original).status());
+}
+
+TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) {
+ const string original = R"(HloModule window_pad_module:
+
+ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
+ %input = f32[1,2,1]{2,1,0} parameter(0)
+ %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
+ %filter = f32[1,1,1]{2,1,0} parameter(1)
+ ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), dim_labels=b0f_0io->b0f, window={pad=1_1_0 size=1}
+}
+
+)";
+ ExpectHasSubstr(Parse(original).status().error_message(),
+ "expects padding_low and padding_high separated by '_'");
+}
+
+TEST_F(HloParserTest, CommaBetweenSubAttributes) {
+ const string original = R"(HloModule test_comma_module:
+
+ENTRY %test_comma.v4 () -> f32[] {
+ ROOT %constant = f32[] constant(-4.2), metadata={source_line=5, op_type="::const"}
+}
+
+)";
+ TF_EXPECT_OK(Parse(original).status());
+}
+
} // namespace
} // namespace tools
} // namespace xla
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/tools/parser/hlo_token.h
index 9c2069e756..78a72837ca 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_token.h
+++ b/tensorflow/compiler/xla/tools/parser/hlo_token.h
@@ -57,6 +57,10 @@ enum class TokKind {
// Typed tokens.
kName, // %foo
kAttributeName, // dimensions=
+ kDimLabels, // [0-9bf]+_[0-9io]+->[0-9bf]+
+ kDxD, // [0-9]+(x[0-9]+)+
+ kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)*
+ kString, // "abcd\"\n"
kShape, // f32[2,3]{1,0}
kOpcode, // add
kInt, // 42
diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc
index 89b26b8916..503e7d456e 100644
--- a/tensorflow/compiler/xla/tools/replay_computation.cc
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -45,6 +45,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h
index 3b19ca321c..9fa4297523 100644
--- a/tensorflow/compiler/xla/types.h
+++ b/tensorflow/compiler/xla/types.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <complex>
#include "third_party/eigen3/Eigen/Core"
+#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/platform/types.h"
#include <Eigen/Core>
@@ -32,6 +33,8 @@ using ::tensorflow::int16;
using ::tensorflow::int32;
using ::tensorflow::int64;
+using ::tensorflow::bfloat16;
+
using ::tensorflow::uint8;
using ::tensorflow::uint16;
using ::tensorflow::uint32;
diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc
index 23161873a0..6f7f1479b9 100644
--- a/tensorflow/compiler/xla/window_util.cc
+++ b/tensorflow/compiler/xla/window_util.cc
@@ -26,8 +26,8 @@ namespace xla {
namespace window_util {
/* static */ string ToString(const WindowDimension& dim) {
- using tensorflow::strings::StrCat;
using tensorflow::strings::StrAppend;
+ using tensorflow::strings::StrCat;
string str = StrCat("(size=", dim.size());
if (dim.stride() != 1) {
StrAppend(&str, ",stride=", dim.stride());
@@ -49,22 +49,22 @@ namespace window_util {
}
string ToString(const Window& window) {
- using tensorflow::strings::StrCat;
using tensorflow::strings::StrAppend;
+ using tensorflow::strings::StrCat;
string str;
- const auto add_field = [&](
- const char* heading,
- std::function<string(const WindowDimension&)> format) {
- StrAppend(&str, heading, "=");
- const char* prefix = "";
- for (const auto& window_dimension : window.dimensions()) {
- StrAppend(&str, prefix, format(window_dimension));
- prefix = "x";
- }
- };
-
- add_field("window",
+ const auto add_field =
+ [&](const char* heading,
+ std::function<string(const WindowDimension&)> format) {
+ StrAppend(&str, heading, "=");
+ const char* prefix = "";
+ for (const auto& window_dimension : window.dimensions()) {
+ StrAppend(&str, prefix, format(window_dimension));
+ prefix = "x";
+ }
+ };
+
+ add_field("size",
[](const WindowDimension& dim) { return StrCat(dim.size()); });
if (HasStride(window)) {
add_field(" stride",
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 06987e0044..eac8f2ff07 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -46,6 +46,12 @@ enum PrimitiveType {
// converted to f16 from f32 at arbirary points in the computation.
F16 = 10;
F32 = 11;
+
+ // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit
+ // floating-point format, but uses 1 bit for the sign, 8 bits for the exponent
+ // and 7 bits for the mantissa.
+ BF16 = 16;
+
F64 = 12;
// Complex values of fixed width.
@@ -63,6 +69,8 @@ enum PrimitiveType {
// An opaque type used for passing context specific data to a custom
// operation.
OPAQUE = 14;
+
+ // Next = 17
}
// Describes the value held inside padding elements.
@@ -310,7 +318,10 @@ message LiteralProto {
repeated double f64s = 9;
repeated float c64s = 12; // Stored as interleaved real, imag floats.
repeated LiteralProto tuple_literals = 10;
- bytes f16s = 11; // Note: the F16s are encoded in little endian byte order
+ // The F16s and BF16s are encoded in little endian byte order
+ bytes f16s = 11;
+ bytes bf16s = 13;
+ // Next = 14
}
message WindowDimension {
@@ -825,8 +836,10 @@ message OpSharding {
REPLICATED = 0;
// This sharding is maximal - one device runs the entire operation.
MAXIMAL = 1;
- // Neither of the above; tile_shape and tile_assignment are both used.
- OTHER = 2;
+ // This sharding is a tuple - only the tuple_shardings field is valid.
+ TUPLE = 2;
+ // None of the above; tile_shape and tile_assignment are both used.
+ OTHER = 3;
}
Type type = 1;
// The shape of the sharded tile.
@@ -838,6 +851,13 @@ message OpSharding {
// Flattened list of device IDs. The order of flattening is the same as used
// by IndexUtil::MultiToLinearIndex(tile_assignment_shape).
repeated int64 tile_assignment_devices = 4;
+ // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape,
+ // in pre-order. The tuple shape could be nested; here we store just a
+ // flattened list of all leaves in the tuple shape. Note that the tuple shape
+ // is not stored here; shardings do not store the shapes to which they are
+ // applied, this is inferred from the instruction this sharding gets attached
+ // to.
+ repeated OpSharding tuple_shardings = 5;
}
message OpRequest {
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 3d53cbba56..b7ade95115 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -51,6 +51,7 @@ py_library(
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/contrib/linear_optimizer:sdca_estimator_py",
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
+ "//tensorflow/contrib/lite/python:lite",
"//tensorflow/contrib/lookup:lookup_py",
"//tensorflow/contrib/losses:losses_py",
"//tensorflow/contrib/losses:metric_learning_py",
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index 3068e9ed8f..1eda1abfcf 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -79,6 +79,7 @@ from tensorflow.contrib import tpu
from tensorflow.contrib import training
from tensorflow.contrib import util
from tensorflow.contrib.eager.python import tfe as eager
+from tensorflow.contrib.lite.python import lite
from tensorflow.contrib.ndlstm import python as ndlstm
from tensorflow.contrib.remote_fused_graph import pylib as remote_fused_graph
from tensorflow.contrib.specs import python as specs
diff --git a/tensorflow/contrib/android/asset_manager_filesystem.cc b/tensorflow/contrib/android/asset_manager_filesystem.cc
index 9e4d3290c3..380a652435 100644
--- a/tensorflow/contrib/android/asset_manager_filesystem.cc
+++ b/tensorflow/contrib/android/asset_manager_filesystem.cc
@@ -97,7 +97,7 @@ class RandomAccessFileFromAsset : public RandomAccessFile {
off64_t new_offset = AAsset_seek64(asset.get(), offset, SEEK_SET);
off64_t length = AAsset_getLength64(asset.get());
if (new_offset < 0) {
- result->set(scratch, 0);
+ *result = StringPiece(scratch, 0);
return errors::OutOfRange("Read after file end.");
}
const off64_t region_left =
@@ -106,7 +106,7 @@ class RandomAccessFileFromAsset : public RandomAccessFile {
if (read < 0) {
return errors::Internal("Error reading from asset.");
}
- result->set(scratch, region_left);
+ *result = StringPiece(scratch, region_left);
return (region_left == to_read)
? Status::OK()
: errors::OutOfRange("Read less bytes than requested.");
diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h
index 5e316538ce..70037d5bd8 100644
--- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h
+++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h
@@ -33,9 +33,9 @@ template <typename ValueType, typename WeightType,
class WeightedQuantilesBuffer {
public:
struct BufferEntry {
- BufferEntry(const ValueType& v, const WeightType& w)
- : value(v), weight(w) {}
- BufferEntry() : value(0), weight(0) {}
+ BufferEntry(ValueType v, WeightType w)
+ : value(std::move(v)), weight(std::move(w)) {}
+ BufferEntry() : value(), weight(0) {}
bool operator<(const BufferEntry& other) const {
return kCompFn(value, other.value);
@@ -67,7 +67,7 @@ class WeightedQuantilesBuffer {
// Push entry to buffer and maintain a compact representation within
// pre-defined size limit.
- void PushEntry(const ValueType& value, const WeightType& weight) {
+ void PushEntry(ValueType value, WeightType weight) {
// Callers are expected to act on a full compacted buffer after the
// PushEntry call returns.
QCHECK(!IsFull()) << "Buffer already full: " << max_size_;
@@ -78,7 +78,7 @@ class WeightedQuantilesBuffer {
}
// Push back the entry to the buffer.
- vec_.push_back(BufferEntry(value, weight));
+ vec_.push_back(BufferEntry(std::move(value), std::move(weight)));
}
// Returns a sorted vector view of the base buffer and clears the buffer.
diff --git a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc
index 82b8e8c1c2..d66f645f62 100644
--- a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc
+++ b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc
@@ -36,7 +36,7 @@ static Status ApplyGradientTreesPredictionShapeFn(InferenceContext* c) {
c->set_output(0, {c->Matrix(InferenceContext::kUnknownDim,
reduce_dim ? learner_config.num_classes() - 1
: learner_config.num_classes())});
- c->set_output(1, {c->Vector(InferenceContext::kUnknownDim)});
+ c->set_output(1, {c->UnknownShape()});
return Status::OK();
}
diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake
index f3882e8cf7..c6a15f2ca0 100644
--- a/tensorflow/contrib/cmake/tf_c.cmake
+++ b/tensorflow/contrib/cmake/tf_c.cmake
@@ -21,7 +21,6 @@ set(tf_c_srcs
"${tensorflow_source_dir}/tensorflow/c/c_api_function.cc"
"${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc"
"${tensorflow_source_dir}/tensorflow/c/eager/c_api.h"
- "${tensorflow_source_dir}/tensorflow/c/eager/tape.cc"
"${tensorflow_source_dir}/tensorflow/c/eager/tape.h"
"${tensorflow_source_dir}/tensorflow/c/eager/runtime.cc"
"${tensorflow_source_dir}/tensorflow/c/eager/runtime.h"
@@ -47,4 +46,5 @@ add_dependencies(
tf_c_python_api
tf_c
tf_core_lib
+ tf_core_framework
tf_protos_cc)
diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake
index c3dc8531bb..c607546f4a 100644
--- a/tensorflow/contrib/cmake/tf_core_framework.cmake
+++ b/tensorflow/contrib/cmake/tf_core_framework.cmake
@@ -301,6 +301,8 @@ file(GLOB_RECURSE tf_core_framework_srcs
"${tensorflow_source_dir}/tensorflow/core/common_runtime/session.cc"
"${tensorflow_source_dir}/tensorflow/core/common_runtime/session_factory.cc"
"${tensorflow_source_dir}/tensorflow/core/common_runtime/session_options.cc"
+ "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*.cc"
+ "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*.h"
"${tensorflow_source_dir}/public/*.h"
)
@@ -314,6 +316,7 @@ file(GLOB_RECURSE tf_core_framework_exclude_srcs
"${tensorflow_source_dir}/tensorflow/core/util/*test*.h"
"${tensorflow_source_dir}/tensorflow/core/util/*test*.cc"
"${tensorflow_source_dir}/tensorflow/core/util/*main.cc"
+ "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*test*.cc"
)
list(REMOVE_ITEM tf_core_framework_srcs ${tf_core_framework_exclude_srcs})
diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake
index a2ab4b9ae4..b1102cecbe 100644
--- a/tensorflow/contrib/cmake/tf_core_kernels.cmake
+++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake
@@ -70,7 +70,6 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
"${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc"
- "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/clustering_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc"
diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake
index 03c168795c..4a61ed7a35 100644
--- a/tensorflow/contrib/cmake/tf_core_ops.cmake
+++ b/tensorflow/contrib/cmake/tf_core_ops.cmake
@@ -81,7 +81,6 @@ GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_prediction "${tensorflow_source_dir}/t
GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_quantiles "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_stats_accumulator "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc")
-GENERATE_CONTRIB_OP_LIBRARY(data_dataset "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(data_prefetching "${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc")
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 43b98659e3..3e2a858b8a 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -499,6 +499,19 @@ add_python_module("tensorflow/contrib/linear_optimizer/kernels/g3doc")
add_python_module("tensorflow/contrib/linear_optimizer/python")
add_python_module("tensorflow/contrib/linear_optimizer/python/kernel_tests")
add_python_module("tensorflow/contrib/linear_optimizer/python/ops")
+add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD
+ COMMAND ${CMAKE_COMMAND} -E make_directory
+ "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite")
+add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD
+ COMMAND ${CMAKE_COMMAND} -E make_directory
+ "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python")
+add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD
+ COMMAND ${CMAKE_COMMAND} -E touch
+ "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python/__init__.py")
+add_custom_command(
+ TARGET tf_python_copy_scripts_to_destination PRE_BUILD
+ COMMAND ${CMAKE_COMMAND} -E touch
+ ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python/lite.py)
add_python_module("tensorflow/contrib/lookup")
add_python_module("tensorflow/contrib/losses")
add_python_module("tensorflow/contrib/losses/python")
@@ -780,8 +793,6 @@ GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_stats_accumulator_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_stats_accumulator_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_cudnn_rnn_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cudnn_rnn/ops/gen_cudnn_rnn_ops.py)
-GENERATE_PYTHON_OP_LIB("contrib_data_dataset_ops"
- DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_dataset_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_data_prefetching_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_prefetching_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_factorization_clustering_ops"
diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD
index 7bcf5a5f4d..eaede0e00e 100644
--- a/tensorflow/contrib/data/BUILD
+++ b/tensorflow/contrib/data/BUILD
@@ -35,19 +35,8 @@ tf_custom_op_library(
],
)
-# TODO(mrry): Move the kernels out of the core library into this library.
-tf_custom_op_library(
- name = "_dataset_ops.so",
- srcs = [
- "ops/dataset_ops.cc",
- ],
-)
-
tf_gen_op_libs(
- op_lib_names = [
- "dataset_ops",
- "prefetching_ops",
- ],
+ op_lib_names = ["prefetching_ops"],
)
filegroup(
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 0c7e793689..6e43ae0e63 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -23,6 +23,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview.
@@TextLineDataset
@@batch_and_drop_remainder
+@@padded_batch_and_drop_remainder
@@dense_to_sparse_batch
@@enumerate_dataset
@@group_by_window
@@ -41,10 +42,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-
# pylint: disable=unused-import
+
from tensorflow.contrib.data.python.ops.batching import batch_and_drop_remainder
from tensorflow.contrib.data.python.ops.batching import dense_to_sparse_batch
+from tensorflow.contrib.data.python.ops.batching import padded_batch_and_drop_remainder
from tensorflow.contrib.data.python.ops.batching import unbatch
from tensorflow.contrib.data.python.ops.dataset_ops import Dataset
from tensorflow.contrib.data.python.ops.dataset_ops import get_single_element
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
deleted file mode 100644
index 1574384cb2..0000000000
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ /dev/null
@@ -1,232 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_def_builder.h"
-#include "tensorflow/core/framework/shape_inference.h"
-
-namespace tensorflow {
-
-// --------------------------------------------------------------------------
-
-// The ops in this section can be composed to define an input
-// pipeline. Each op produces a DT_VARIANT tensor that represents
-// a DAG of "dataset" objects. An "dataset" object can be converted
-// to a stateful "iterator" by passing the "dataset" to the
-// "MakeIterator" op.
-//
-// TODO(b/65524810): DT_VARIANT tensors that represent "dataset" objects are
-// not presently serializable. To avoid issues with constant folding, ensure
-// that any "source dataset" ops (i.e. ops that output a dataset and do not
-// take one as input) are marked "stateful".
-
-REGISTER_OP("IgnoreErrorsDataset")
- .Input("input_dataset: variant")
- .Output("handle: variant")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that contains the elements of `input_dataset` ignoring errors.
-)doc");
-
-REGISTER_OP("MapAndBatchDataset")
- .Input("input_dataset: variant")
- .Input("other_arguments: Targuments")
- .Input("batch_size: int64")
- .Input("num_parallel_batches: int64")
- .Output("handle: variant")
- .Attr("f: func")
- .Attr("Targuments: list(type) >= 0")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that applies `f` to the outputs of `input_dataset` and then
-batches `batch_size` of them.
-
-Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up
-to `batch_size * num_parallel_batches` copies of `f` in parallel.
-
-batch_size: A scalar representing the number of elements to accumulate in a
- batch. It determines the number of concurrent invocations of `f` that process
- elements from `input_dataset` in parallel.
-num_parallel_batches: A scalar representing the number of batches to create in
- parallel. Processing multiple batches in parallel benefits workloads prone to
- stragglers.
-)doc");
-
-REGISTER_OP("ScanDataset")
- .Input("input_dataset: variant")
- .Input("initial_state: Tstate")
- .Input("other_arguments: Targuments")
- .Output("handle: variant")
- .Attr("f: func")
- .Attr("Tstate: list(type) >= 1")
- .Attr("Targuments: list(type) >= 0")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset successively reduces `f` over the elements of `input_dataset`.
-)doc");
-
-REGISTER_OP("ParallelInterleaveDataset")
- .Input("input_dataset: variant")
- .Input("other_arguments: Targuments")
- .Input("cycle_length: int64")
- .Input("block_length: int64")
- .Input("sloppy: bool")
- .Output("handle: variant")
- .Attr("f: func")
- .Attr("Targuments: list(type) >= 0")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that applies `f` to the outputs of `input_dataset`.
-
-The resulting dataset is similar to the `InterleaveDataset`, with the exception
-that if retrieving the next value from a dataset would cause the requester to
-block, it will skip that input dataset. This dataset is especially useful
-when loading data from a variable-latency datastores (e.g. HDFS, GCS), as it
-allows the training step to proceed so long as some data is available.
-
-!! WARNING !! This dataset is not deterministic!
-
-f: A function mapping elements of `input_dataset`, concatenated with
- `other_arguments`, to a Dataset variant that contains elements matching
- `output_types` and `output_shapes`.
-)doc");
-
-REGISTER_OP("GroupByWindowDataset")
- .Input("input_dataset: variant")
- .Input("key_func_other_arguments: Tkey_func_other_arguments")
- .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
- .Input(
- "window_size_func_other_arguments: Twindow_size_func_other_arguments")
- .Output("handle: variant")
- .Attr("key_func: func")
- .Attr("reduce_func: func")
- .Attr("window_size_func: func")
- .Attr("Tkey_func_other_arguments: list(type) >= 0")
- .Attr("Treduce_func_other_arguments: list(type) >= 0")
- .Attr("Twindow_size_func_other_arguments: list(type) >= 0")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that computes a windowed group-by on `input_dataset`.
-
-// TODO(mrry): Support non-int64 keys.
-
-key_func: A function mapping an element of `input_dataset`, concatenated
- with `key_func_other_arguments` to a scalar value of type DT_INT64.
-)doc");
-
-REGISTER_OP("DenseToSparseBatchDataset")
- .Input("input_dataset: variant")
- .Input("batch_size: int64")
- .Input("row_shape: int64")
- .Output("handle: variant")
- // NOTE(mrry): the 0th and 2nd elements will be DT_INT64.
- .Attr("output_types: list(type) >= 1")
- // NOTE(mrry): the 1st and 2nd elements will be vectors.
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that yields a SparseTensor for each element of the input.
-
-input_dataset: A handle to an input dataset. Must have a single component.
-batch_size: A scalar representing the number of elements to accumulate in a
- batch.
-row_shape: A vector representing the dense shape of each row in the produced
- SparseTensor. The shape may be partially specified, using `-1` to indicate
- that a particular dimension should use the maximum size of all batch elements.
-)doc");
-
-REGISTER_OP("SqlDataset")
- .Input("driver_name: string")
- .Input("data_source_name: string")
- .Input("query: string")
- .Output("handle: variant")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
- // stateful to inhibit constant folding.
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that executes a SQL query and emits rows of the result set.
-
-driver_name: The database type. Currently, the only supported type is 'sqlite'.
-data_source_name: A connection string to connect to the database.
-query: A SQL query to execute.
-)doc");
-
-REGISTER_OP("DatasetToSingleElement")
- .Input("dataset: variant")
- .Output("components: output_types")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- shape_inference::ShapeHandle unused;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
- std::vector<PartialTensorShape> output_shapes;
- TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
- if (output_shapes.size() != c->num_outputs()) {
- return errors::InvalidArgument(
- "`output_shapes` must be the same length as `output_types` (",
- output_shapes.size(), " vs. ", c->num_outputs());
- }
- for (size_t i = 0; i < output_shapes.size(); ++i) {
- shape_inference::ShapeHandle output_shape_handle;
- TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
- output_shapes[i], &output_shape_handle));
- c->set_output(static_cast<int>(i), output_shape_handle);
- }
- return Status::OK();
- })
- .Doc(R"doc(
-Outputs the single element from the given dataset.
-
-dataset: A handle to a dataset that contains a single element.
-components: The components of the single element of `input`.
-)doc");
-
-REGISTER_OP("SerializeIterator")
- .Input("resource_handle: resource")
- .Output("serialized: variant")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Converts the given `resource_handle` representing an iterator to a variant tensor.
-
-resource_handle: A handle to an iterator resource.
-serialized: A variant tensor storing the state of the iterator contained in the
- resource.
-)doc");
-
-REGISTER_OP("DeserializeIterator")
- .Input("resource_handle: resource")
- .Input("serialized: variant")
- .SetShapeFn(shape_inference::NoOutputs)
- .Doc(R"doc(
-Converts the given variant tensor to an iterator and stores it in the given resource.
-
-resource_handle: A handle to an iterator resource.
-serialized: A variant tensor storing the state of the iterator contained in the
- resource.
-)doc");
-
-} // namespace tensorflow
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 5877f42dcf..6723f92e08 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -366,6 +366,7 @@ py_test(
srcs = ["sequence_dataset_op_test.py"],
srcs_version = "PY2AND3",
deps = [
+ ":dataset_serialization_test",
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -429,6 +430,7 @@ py_test(
srcs = ["zip_dataset_op_test.py"],
srcs_version = "PY2AND3",
deps = [
+ ":dataset_serialization_test",
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/data/python/ops:iterator_ops",
"//tensorflow/python:array_ops",
@@ -449,7 +451,7 @@ py_test(
srcs_version = "PY2AND3",
tags = [
"manual",
- "no_oss",
+ "no_oss", # b/68785503
],
deps = [
"//tensorflow/contrib/data/python/ops:dataset_ops",
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index 670f622c3c..951d4bb5f7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -52,8 +52,9 @@ class BatchDatasetTest(test.TestCase):
def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
- iterator = (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
- .repeat(count).batch(batch_size).make_initializable_iterator())
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
+ .repeat(count).batch(batch_size).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
@@ -69,7 +70,7 @@ class BatchDatasetTest(test.TestCase):
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(14):
- self.assertAllEqual(component[(i*14 + j) % 7]**2,
+ self.assertAllEqual(component[(i * 14 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -84,12 +85,12 @@ class BatchDatasetTest(test.TestCase):
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(8):
- self.assertAllEqual(component[(i*8 + j) % 7]**2,
+ self.assertAllEqual(component[(i * 8 + j) % 7]**2,
result_component[j])
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range((14 * 7) % 8):
- self.assertAllEqual(component[((num_batches - 1)*8 + j) % 7]**2,
+ self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -107,10 +108,10 @@ class BatchDatasetTest(test.TestCase):
seq_lens = array_ops.placeholder(dtypes.int32, shape=[None])
padded_shape = array_ops.placeholder(dtypes.int64, shape=[1])
- iterator = (dataset_ops.Dataset.from_tensor_slices(seq_lens)
- .map(lambda x: array_ops.fill([x], x)).padded_batch(
- 4,
- padded_shapes=padded_shape).make_initializable_iterator())
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(seq_lens)
+ .map(lambda x: array_ops.fill([x], x)).padded_batch(
+ 4, padded_shapes=padded_shape).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
@@ -118,35 +119,40 @@ class BatchDatasetTest(test.TestCase):
with self.test_session() as sess:
# Test with random sequence lengths, and max padding.
random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
- sess.run(init_op, feed_dict={padded_shape: [-1],
- seq_lens: random_seq_lens})
+ sess.run(
+ init_op, feed_dict={
+ padded_shape: [-1],
+ seq_lens: random_seq_lens
+ })
for i in range(8):
result = sess.run(get_next)
padded_len = np.max(result)
self.assertEqual((4, padded_len), result.shape)
for j in range(4):
- seq_len = random_seq_lens[(i*4)+j]
+ seq_len = random_seq_lens[(i * 4) + j]
self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
self.assertAllEqual(result[j, seq_len:], [0] * (padded_len - seq_len))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test with random sequence lengths, and constant padding.
- sess.run(init_op, feed_dict={padded_shape: [25],
- seq_lens: random_seq_lens})
+ sess.run(
+ init_op, feed_dict={
+ padded_shape: [25],
+ seq_lens: random_seq_lens
+ })
for i in range(8):
result = sess.run(get_next)
self.assertEqual((4, 25), result.shape)
for j in range(4):
- seq_len = random_seq_lens[(i*4)+j]
+ seq_len = random_seq_lens[(i * 4) + j]
self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
self.assertAllEqual(result[j, seq_len:], [0] * (25 - seq_len))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test correct handling of empty tensors.
- sess.run(init_op, feed_dict={padded_shape: [-1],
- seq_lens: [0, 0, 0, 0]})
+ sess.run(init_op, feed_dict={padded_shape: [-1], seq_lens: [0, 0, 0, 0]})
result = sess.run(get_next)
self.assertAllEqual([[], [], [], []], result)
with self.assertRaises(errors.OutOfRangeError):
@@ -154,8 +160,7 @@ class BatchDatasetTest(test.TestCase):
# Test error handling with constant sequence lengths, and
# too-short padding.
- sess.run(init_op, feed_dict={padded_shape: [5],
- seq_lens: [6, 5, 5, 5]})
+ sess.run(init_op, feed_dict={padded_shape: [5], seq_lens: [6, 5, 5, 5]})
with self.assertRaises(errors.DataLossError):
result = sess.run(get_next)
@@ -166,11 +171,13 @@ class BatchDatasetTest(test.TestCase):
def fill_tuple(x):
filled = array_ops.fill([x], x)
return (filled, string_ops.as_string(filled))
- iterator = (dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple)
- .padded_batch(
- 4,
- padded_shapes=(padded_shape, padded_shape),
- padding_values=(-1, "<end>")).make_initializable_iterator())
+
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple)
+ .padded_batch(
+ 4,
+ padded_shapes=(padded_shape, padded_shape),
+ padding_values=(-1, "<end>")).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
@@ -178,15 +185,18 @@ class BatchDatasetTest(test.TestCase):
with self.test_session() as sess:
# Test with random sequence lengths, and max padding.
random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
- sess.run(init_op, feed_dict={padded_shape: [-1],
- seq_lens: random_seq_lens})
+ sess.run(
+ init_op, feed_dict={
+ padded_shape: [-1],
+ seq_lens: random_seq_lens
+ })
for i in range(8):
result = sess.run(get_next)
padded_len = np.max(result[0])
self.assertEqual((4, padded_len), result[0].shape)
self.assertEqual((4, padded_len), result[1].shape)
for j in range(4):
- seq_len = random_seq_lens[(i*4)+j]
+ seq_len = random_seq_lens[(i * 4) + j]
self.assertAllEqual(result[0][j, :seq_len], [seq_len] * seq_len)
self.assertAllEqual(result[0][j, seq_len:],
[-1] * (padded_len - seq_len))
@@ -220,20 +230,21 @@ class BatchDatasetTest(test.TestCase):
constant_op.constant([-1, -1], dtype=dtypes.int64),
constant_op.constant([37], dtype=dtypes.int64)))
- for dataset in [dynamic_padding_from_tensor_shapes,
- dynamic_padding_from_lists,
- dynamic_padding_from_lists_with_minus_one,
- dynamic_padding_from_tensors]:
+ for dataset in [
+ dynamic_padding_from_tensor_shapes, dynamic_padding_from_lists,
+ dynamic_padding_from_lists_with_minus_one, dynamic_padding_from_tensors
+ ]:
self.assertEqual([None, None], dataset.output_shapes[0].as_list())
self.assertEqual([None, None, None], dataset.output_shapes[1].as_list())
self.assertEqual([None, 37], dataset.output_shapes[2].as_list())
def testDenseToSparseBatchDataset(self):
components = np.random.randint(12, size=(100,)).astype(np.int32)
- iterator = (dataset_ops.Dataset.from_tensor_slices(components)
- .map(lambda x: array_ops.fill([x], x)).apply(
- batching.dense_to_sparse_batch(4, [12]))
- .make_initializable_iterator())
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components)
+ .map(lambda x: array_ops.fill([x], x)).apply(
+ batching.dense_to_sparse_batch(4,
+ [12])).make_initializable_iterator())
init_op = iterator.initializer
get_next = sparse_tensor.SparseTensor(*iterator.get_next())
@@ -242,24 +253,26 @@ class BatchDatasetTest(test.TestCase):
for start in range(0, len(components), 4):
results = sess.run(get_next)
+ self.assertAllEqual([[i, j]
+ for i, c in enumerate(components[start:start + 4])
+ for j in range(c)], results.indices)
self.assertAllEqual(
- [[i, j] for i, c in enumerate(components[start:start+4])
- for j in range(c)], results.indices)
- self.assertAllEqual(
- [c for c in components[start:start+4] for _ in range(c)],
+ [c for c in components[start:start + 4] for _ in range(c)],
results.values)
- self.assertAllEqual(
- [min(4, len(components) - start), 12], results.dense_shape)
+ self.assertAllEqual([min(4,
+ len(components) - start), 12],
+ results.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testDenseToSparseBatchDatasetWithUnknownShape(self):
components = np.random.randint(5, size=(40,)).astype(np.int32)
- iterator = (dataset_ops.Dataset.from_tensor_slices(components)
- .map(lambda x: array_ops.fill([x, x], x)).apply(
- batching.dense_to_sparse_batch(
- 4, [5, -1])).make_initializable_iterator())
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components)
+ .map(lambda x: array_ops.fill([x, x], x)).apply(
+ batching.dense_to_sparse_batch(
+ 4, [5, -1])).make_initializable_iterator())
init_op = iterator.initializer
get_next = sparse_tensor.SparseTensor(*iterator.get_next())
@@ -268,27 +281,30 @@ class BatchDatasetTest(test.TestCase):
for start in range(0, len(components), 4):
results = sess.run(get_next)
- self.assertAllEqual(
- [[i, j, z] for i, c in enumerate(components[start:start+4])
- for j in range(c) for z in range(c)], results.indices)
- self.assertAllEqual(
- [c for c in components[start:start+4]
- for _ in range(c) for _ in range(c)],
- results.values)
- self.assertAllEqual(
- [min(4, len(components) - start),
- 5,
- np.max(components[start:start+4])],
- results.dense_shape)
+ self.assertAllEqual([[i, j, z]
+ for i, c in enumerate(components[start:start + 4])
+ for j in range(c)
+ for z in range(c)], results.indices)
+ self.assertAllEqual([
+ c
+ for c in components[start:start + 4] for _ in range(c)
+ for _ in range(c)
+ ], results.values)
+ self.assertAllEqual([
+ min(4,
+ len(components) - start), 5,
+ np.max(components[start:start + 4])
+ ], results.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testDenseToSparseBatchDatasetWithInvalidShape(self):
input_tensor = array_ops.constant([[1]])
- iterator = (dataset_ops.Dataset.from_tensors(input_tensor)
- .apply(batching.dense_to_sparse_batch(4, [-2]))
- .make_initializable_iterator())
+ iterator = (
+ dataset_ops.Dataset.from_tensors(input_tensor).apply(
+ batching.dense_to_sparse_batch(4, [-2]))
+ .make_initializable_iterator())
init_op = iterator.initializer
with self.test_session() as sess:
@@ -298,8 +314,10 @@ class BatchDatasetTest(test.TestCase):
def testDenseToSparseBatchDatasetShapeErrors(self):
input_tensor = array_ops.placeholder(dtypes.int32)
- iterator = (dataset_ops.Dataset.from_tensors(input_tensor).apply(
- batching.dense_to_sparse_batch(4, [12])).make_initializable_iterator())
+ iterator = (
+ dataset_ops.Dataset.from_tensors(input_tensor).apply(
+ batching.dense_to_sparse_batch(4,
+ [12])).make_initializable_iterator())
init_op = iterator.initializer
get_next = sparse_tensor.SparseTensor(*iterator.get_next())
@@ -356,8 +374,7 @@ class BatchDatasetTest(test.TestCase):
def testUnbatchMultiElementTupleDataset(self):
data = tuple([(math_ops.range(10 * i, 10 * i + 10),
- array_ops.fill([10], "hi"))
- for i in range(3)])
+ array_ops.fill([10], "hi")) for i in range(3)])
data = dataset_ops.Dataset.from_tensor_slices(data)
expected_types = ((dtypes.int32, dtypes.string),) * 3
data = data.batch(2)
@@ -370,9 +387,7 @@ class BatchDatasetTest(test.TestCase):
with self.test_session() as sess:
for i in range(10):
- self.assertEqual(((i, b"hi"),
- (10 + i, b"hi"),
- (20 + i, b"hi")),
+ self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
sess.run(op))
with self.assertRaises(errors.OutOfRangeError):
@@ -385,9 +400,10 @@ class BatchDatasetTest(test.TestCase):
batch_size = array_ops.placeholder(dtypes.int64, shape=[])
- iterator = (dataset_ops.Dataset.from_tensor_slices(components).apply(
- batching.batch_and_drop_remainder(batch_size))
- .make_initializable_iterator())
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components).apply(
+ batching.batch_and_drop_remainder(batch_size))
+ .make_initializable_iterator())
next_element = iterator.get_next()
@@ -404,14 +420,51 @@ class BatchDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
+ def testPaddedBatchAndDropRemainder(self):
+ els = []
+ for length in [3, 6, 9, 4, 12, 10, 2]:
+ els.append((np.array(length), np.arange(length) + 1,
+ np.array(length * 2)))
+
+ dataset = dataset_ops.Dataset.from_tensors(els[0])
+ for el in els[1:]:
+ dataset = dataset.concatenate(dataset_ops.Dataset.from_tensors(el))
+
+ batch_size = array_ops.placeholder(dtypes.int64, shape=[])
+ iterator = (
+ dataset.apply(
+ batching.padded_batch_and_drop_remainder(
+ batch_size, ([], [None], []))).make_initializable_iterator())
+
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ for test_batch_size in [1, 3, 7, 10]:
+ sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
+ num_batches = 7 // test_batch_size
+ for i in range(num_batches):
+ result = sess.run(next_element)
+ for component_idx, result_component in enumerate(result):
+ for j in range(test_batch_size):
+ data_idx = i * test_batch_size + j
+ comp = result_component[j]
+ unpadded = comp[comp > 0]
+ if np.isscalar(comp):
+ # The boolean mask indexing above adds a dim back. Rm it.
+ unpadded = unpadded[0]
+ self.assertAllEqual(els[data_idx][component_idx], unpadded)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
def testBatchAndDropRemainderShapeInference(self):
- components = (array_ops.placeholder(dtypes.int32), (array_ops.placeholder(
- dtypes.int32, shape=[None]), array_ops.placeholder(
- dtypes.int32, shape=[20, 30])))
+ components = (array_ops.placeholder(dtypes.int32),
+ (array_ops.placeholder(dtypes.int32, shape=[None]),
+ array_ops.placeholder(dtypes.int32, shape=[20, 30])))
# Test with a statically known batch size.
- dataset = (dataset_ops.Dataset.from_tensor_slices(components).apply(
- batching.batch_and_drop_remainder(128)))
+ dataset = (
+ dataset_ops.Dataset.from_tensor_slices(components).apply(
+ batching.batch_and_drop_remainder(128)))
self.assertIs(None, dataset.output_shapes[0].ndims)
self.assertEqual([128], dataset.output_shapes[1][0].as_list())
@@ -420,8 +473,9 @@ class BatchDatasetTest(test.TestCase):
# Test with a dynamic batch size: the static shape will be unknown, because
# `batch_size` is a placeholder.
batch_size = array_ops.placeholder(dtypes.int64)
- dataset = (dataset_ops.Dataset.from_tensor_slices(components).apply(
- batching.batch_and_drop_remainder(batch_size)))
+ dataset = (
+ dataset_ops.Dataset.from_tensor_slices(components).apply(
+ batching.batch_and_drop_remainder(batch_size)))
self.assertIs(None, dataset.output_shapes[0].ndims)
self.assertEqual([None], dataset.output_shapes[1][0].as_list())
@@ -441,9 +495,10 @@ class BatchDatasetTest(test.TestCase):
def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
- iterator = (dataset_ops.Dataset.from_tensor_slices(components).repeat(count)
- .apply(batching.map_and_batch(_map_fn, batch_size))
- .make_initializable_iterator())
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply(
+ batching.map_and_batch(_map_fn, batch_size))
+ .make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
@@ -459,7 +514,7 @@ class BatchDatasetTest(test.TestCase):
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(14):
- self.assertAllEqual(component[(i*14 + j) % 7]**2,
+ self.assertAllEqual(component[(i * 14 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -474,7 +529,7 @@ class BatchDatasetTest(test.TestCase):
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(8):
- self.assertAllEqual(component[(i*8 + j) % 7]**2,
+ self.assertAllEqual(component[(i * 8 + j) % 7]**2,
result_component[j])
# The last batch should fail with `OutOfRange`.
with self.assertRaises(errors.OutOfRangeError):
@@ -495,8 +550,9 @@ class BatchDatasetTest(test.TestCase):
array_ops.check_numerics(
constant_op.constant(1.0) / constant_op.constant(0.0), "oops"))
batch_size = array_ops.placeholder(dtypes.int64, shape=[])
- iterator = (dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
- .make_initializable_iterator())
+ iterator = (
+ dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
+ .make_initializable_iterator())
init_op = iterator.initializer
with self.test_session() as sess:
with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
@@ -504,6 +560,7 @@ class BatchDatasetTest(test.TestCase):
def testBatchAndMapDatasetShapeMismatch(self):
"""Test a dataset that maps a TF function across its input elements."""
+
def generator():
yield [1]
yield [2]
diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
index df9147af6c..07fecf04fa 100644
--- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
@@ -32,7 +32,7 @@ from tensorflow.python.util import nest
class DatasetSerializationTestBase(test.TestCase):
- """Base class for testing finite serializable datasets."""
+ """Base class for testing serializable datasets."""
def tearDown(self):
self._delete_ckpt()
@@ -58,17 +58,19 @@ class DatasetSerializationTestBase(test.TestCase):
if ds_fn2:
self.verify_restore_in_modified_graph(ds_fn1, ds_fn2, num_outputs)
- def verify_unused_iterator(self, ds_fn, num_outputs):
+ def verify_unused_iterator(self, ds_fn, num_outputs, verify_exhausted=True):
"""Verifies that saving and restoring an unused iterator works.
Args:
ds_fn: See `run_core_tests`.
num_outputs: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
Raises:
AssertionError if any test fails.
"""
- self.verify_run_with_breaks(ds_fn, [0], num_outputs)
+ self.verify_run_with_breaks(
+ ds_fn, [0], num_outputs, verify_exhausted=verify_exhausted)
def verify_fully_used_iterator(self, ds_fn, num_outputs):
"""Verifies that saving and restoring a fully used iterator works.
@@ -104,12 +106,16 @@ class DatasetSerializationTestBase(test.TestCase):
ds_fn, [], 0, ckpt_saved=True, verify_exhausted=True)
self.assertEqual(len(actual), 0)
- def verify_init_before_restore(self, ds_fn, num_outputs):
+ def verify_init_before_restore(self,
+ ds_fn,
+ num_outputs,
+ verify_exhausted=True):
"""Verifies that retoring into an already initilized iterator works.
Args:
ds_fn: See `run_core_tests`.
num_outputs: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
Raises:
AssertionError if any test fails.
@@ -118,9 +124,14 @@ class DatasetSerializationTestBase(test.TestCase):
ds_fn,
self.gen_break_points(num_outputs),
num_outputs,
- init_before_restore=True)
+ init_before_restore=True,
+ verify_exhausted=verify_exhausted)
- def verify_multiple_breaks(self, ds_fn, num_outputs, num_breaks=10):
+ def verify_multiple_breaks(self,
+ ds_fn,
+ num_outputs,
+ num_breaks=10,
+ verify_exhausted=True):
"""Attempts to save/restore at multiple break points.
Args:
@@ -128,16 +139,22 @@ class DatasetSerializationTestBase(test.TestCase):
num_outputs: See `run_core_tests`.
num_breaks: The number of break points. These are uniformly spread in
[0, num_outputs] both inclusive.
+ verify_exhausted: See `gen_outputs`.
Raises:
AssertionError if any test fails.
"""
- self.verify_run_with_breaks(ds_fn,
- self.gen_break_points(num_outputs, num_breaks),
- num_outputs)
+ self.verify_run_with_breaks(
+ ds_fn,
+ self.gen_break_points(num_outputs),
+ num_outputs,
+ verify_exhausted=verify_exhausted)
- def verify_reset_restored_iterator(self, ds_fn, num_outputs,
- break_point=None):
+ def verify_reset_restored_iterator(self,
+ ds_fn,
+ num_outputs,
+ break_point=None,
+ verify_exhausted=True):
"""Attempts to re-initialize a restored iterator.
This is useful when restoring a training checkpoint during validation.
@@ -146,6 +163,7 @@ class DatasetSerializationTestBase(test.TestCase):
ds_fn: See `run_core_tests`.
num_outputs: See `run_core_tests`.
break_point: Break point. Optional. Defaults to num_outputs/2.
+ verify_exhausted: See `gen_outputs`.
Raises:
AssertionError if any test fails.
@@ -153,7 +171,8 @@ class DatasetSerializationTestBase(test.TestCase):
break_point = num_outputs // 2 if not break_point else break_point
# Collect ground truth containing all outputs.
- expected = self.gen_outputs(ds_fn, [], num_outputs, verify_exhausted=True)
+ expected = self.gen_outputs(
+ ds_fn, [], num_outputs, verify_exhausted=verify_exhausted)
# Skip some items and save checkpoint.
self.gen_outputs(ds_fn, [], break_point, verify_exhausted=False)
@@ -168,15 +187,17 @@ class DatasetSerializationTestBase(test.TestCase):
sess.run(init_op)
for _ in range(num_outputs):
actual.append(sess.run(get_next_op))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
+ if verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
self.match(expected, actual)
def verify_restore_in_modified_graph(self,
ds_fn1,
ds_fn2,
num_outputs,
- break_point=None):
+ break_point=None,
+ verify_exhausted=True):
"""Attempts to restore an iterator in a modified graph.
Builds an input pipeline using ds_fn1, runs it for `break_point` steps
@@ -188,6 +209,7 @@ class DatasetSerializationTestBase(test.TestCase):
ds_fn2: See `run_core_tests`.
num_outputs: See `run_core_tests`.
break_point: Break point. Optional. Defaults to num_outputs/2.
+ verify_exhausted: See `gen_outputs`.
Raises:
AssertionError if any test fails.
@@ -196,15 +218,15 @@ class DatasetSerializationTestBase(test.TestCase):
# Skip `break_point` items and store the remaining produced from ds_fn1
# in `expected`.
- self.gen_outputs(ds_fn1, [], break_point)
+ self.gen_outputs(ds_fn1, [], break_point, verify_exhausted=False)
expected = self.gen_outputs(
ds_fn1, [],
num_outputs - break_point,
ckpt_saved=True,
- verify_exhausted=True)
+ verify_exhausted=verify_exhausted)
# Generate `break_point` items from ds_fn1 and save checkpoint.
- self.gen_outputs(ds_fn1, [], break_point)
+ self.gen_outputs(ds_fn1, [], break_point, verify_exhausted=False)
actual = []
# Build graph for ds_fn2 but load checkpoint for ds_fn1.
@@ -214,8 +236,9 @@ class DatasetSerializationTestBase(test.TestCase):
self._restore(saver, sess)
for _ in range(num_outputs - break_point):
actual.append(sess.run(get_next_op))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
+ if verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
self.match(expected, actual)
@@ -223,6 +246,7 @@ class DatasetSerializationTestBase(test.TestCase):
ds_fn,
break_points,
num_outputs,
+ verify_exhausted=True,
init_before_restore=False):
"""Verifies that ds_fn() produces the same outputs with and without breaks.
@@ -237,6 +261,7 @@ class DatasetSerializationTestBase(test.TestCase):
ds_fn: See `gen_outputs`.
break_points: See `gen_outputs`.
num_outputs: See `gen_outputs`.
+ verify_exhausted: See `gen_outputs`.
init_before_restore: See `gen_outputs`.
Raises:
@@ -245,13 +270,13 @@ class DatasetSerializationTestBase(test.TestCase):
expected = self.gen_outputs(
ds_fn, [],
num_outputs,
- verify_exhausted=True,
+ verify_exhausted=verify_exhausted,
init_before_restore=init_before_restore)
actual = self.gen_outputs(
ds_fn,
break_points,
num_outputs,
- verify_exhausted=True,
+ verify_exhausted=verify_exhausted,
init_before_restore=init_before_restore)
self.match(expected, actual)
@@ -261,7 +286,7 @@ class DatasetSerializationTestBase(test.TestCase):
num_outputs,
ckpt_saved=False,
init_before_restore=False,
- verify_exhausted=False):
+ verify_exhausted=True):
"""Generates elements from input dataset while stopping at break points.
Produces `num_outputs` outputs and saves the state of the iterator in the
@@ -285,7 +310,7 @@ class DatasetSerializationTestBase(test.TestCase):
after producing `num_outputs` elements.
Returns:
- A list if `num_outputs` items.
+ A list of `num_outputs` items.
"""
outputs = []
@@ -312,11 +337,11 @@ class DatasetSerializationTestBase(test.TestCase):
num_iters = end - start
for _ in range(num_iters):
outputs.append(sess.run(get_next_op))
- self._save(sess, saver)
- ckpt_saved = True
if i == len(break_points) and verify_exhausted:
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
+ self._save(sess, saver)
+ ckpt_saved = True
return outputs
diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
index 271d80a54b..bda9a2a4a3 100644
--- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
@@ -21,7 +21,6 @@ import os
import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.contrib.data.python.ops import readers
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
@@ -34,6 +33,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import math_ops
diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
index 329dc80ba5..f59ac760dc 100644
--- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
@@ -21,7 +21,6 @@ import os
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.ops import enumerate_ops
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
@@ -30,6 +29,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variables
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index 8033f1d388..3ae8f71d77 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -21,7 +21,6 @@ import gzip
import os
import zlib
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
from tensorflow.contrib.data.python.ops import readers
from tensorflow.core.example import example_pb2
@@ -34,6 +33,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py
index 91615e9f62..1a26da82e5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -207,5 +208,82 @@ class SequenceDatasetTest(test.TestCase):
sess.run(get_next)
+class SequenceDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_skip_dataset(self, count):
+ components = (np.arange(10),)
+ return dataset_ops.Dataset.from_tensor_slices(components).skip(count)
+
+ def testSkipFewerThanInputs(self):
+ count = 4
+ num_outputs = 10 - count
+ self.run_core_tests(lambda: self._build_skip_dataset(count),
+ lambda: self._build_skip_dataset(count + 2),
+ num_outputs)
+
+ def testSkipVarious(self):
+ # Skip more than inputs
+ self.run_core_tests(lambda: self._build_skip_dataset(20), None, 0)
+ # Skip exactly the input size
+ self.run_core_tests(lambda: self._build_skip_dataset(10), None, 0)
+ self.run_core_tests(lambda: self._build_skip_dataset(-1), None, 0)
+ # Skip nothing
+ self.run_core_tests(lambda: self._build_skip_dataset(0), None, 10)
+
+ def _build_take_dataset(self, count):
+ components = (np.arange(10),)
+ return dataset_ops.Dataset.from_tensor_slices(components).take(count)
+
+ def testTakeFewerThanInputs(self):
+ count = 4
+ self.run_core_tests(
+ lambda: self._build_take_dataset(count),
+ lambda: self._build_take_dataset(count + 2),
+ count,
+ )
+
+ def testTakeVarious(self):
+ # Take more than inputs
+ self.run_core_tests(lambda: self._build_take_dataset(20), None, 10)
+ # Take exactly the input size
+ self.run_core_tests(lambda: self._build_take_dataset(10), None, 10)
+ # Take all
+ self.run_core_tests(lambda: self._build_take_dataset(-1), None, 10)
+ # Take nothing
+ self.run_core_tests(lambda: self._build_take_dataset(0), None, 0)
+
+ def _build_repeat_dataset(self, count, take_count=3):
+ components = (np.arange(10),)
+ return dataset_ops.Dataset.from_tensor_slices(components).take(
+ take_count).repeat(count)
+
+ def testFiniteRepeat(self):
+ count = 10
+ self.run_core_tests(lambda: self._build_repeat_dataset(count),
+ lambda: self._build_repeat_dataset(count + 2),
+ 3 * count)
+
+ def testEmptyRepeat(self):
+ self.run_core_tests(lambda: self._build_repeat_dataset(0), None, 0)
+
+ def testInfiniteRepeat(self):
+ self.verify_unused_iterator(
+ lambda: self._build_repeat_dataset(-1), 10, verify_exhausted=False)
+ self.verify_init_before_restore(
+ lambda: self._build_repeat_dataset(-1), 10, verify_exhausted=False)
+ self.verify_multiple_breaks(
+ lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False)
+ self.verify_reset_restored_iterator(
+ lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False)
+ self.verify_restore_in_modified_graph(
+ lambda: self._build_repeat_dataset(-1),
+ lambda: self._build_repeat_dataset(2),
+ 20,
+ verify_exhausted=False)
+ # Test repeat empty dataset
+ self.run_core_tests(lambda: self._build_repeat_dataset(-1, 0), None, 0)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py
index b0e7218301..5d34b0024c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -110,5 +111,31 @@ class ZipDatasetTest(test.TestCase):
sess.run(get_next)
+class ZipDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, arr):
+ components = [
+ np.tile(np.array([[1], [2], [3], [4]]), 20),
+ np.tile(np.array([[12], [13], [14], [15]]), 22),
+ np.array(arr)
+ ]
+ datasets = [
+ dataset_ops.Dataset.from_tensor_slices(component)
+ for component in components
+ ]
+ return dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2])))
+
+ def testCore(self):
+ # Equal length components
+ arr = [37.0, 38.0, 39.0, 40.0]
+ num_outputs = len(arr)
+ self.run_core_tests(lambda: self._build_dataset(arr), None, num_outputs)
+ # Variable length components
+ diff_size_arr = [1.0, 2.0]
+ self.run_core_tests(lambda: self._build_dataset(diff_size_arr),
+ lambda: self._build_dataset(arr), 2)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 727c5d1c38..1b81cf5be9 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -12,6 +12,20 @@ load(
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
py_library(
+ name = "dataset_ops",
+ srcs = [
+ "dataset_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":transformation_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+py_library(
name = "iterator_ops",
srcs = [
"iterator_ops.py",
@@ -59,7 +73,6 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- ":gen_dataset_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dataset_ops_gen",
@@ -115,31 +128,6 @@ tf_custom_op_py_library(
],
)
-tf_gen_op_wrapper_py(
- name = "gen_dataset_ops",
- out = "gen_dataset_ops.py",
- deps = ["//tensorflow/contrib/data:dataset_ops_op_lib"],
-)
-
-tf_custom_op_py_library(
- name = "dataset_ops",
- srcs = ["dataset_ops.py"],
- dso = ["//tensorflow/contrib/data:_dataset_ops.so"],
- kernels = [
- "//tensorflow/contrib/data:dataset_ops_op_lib",
- ],
- srcs_version = "PY2AND3",
- deps = [
- ":gen_dataset_ops",
- ":transformation_ops",
- "//tensorflow/contrib/util:util_py",
- "//tensorflow/python:platform",
- "//tensorflow/python:util",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- ],
-)
-
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index e6e5f716b6..d4ade7adfd 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -17,7 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
@@ -25,6 +24,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import math_ops
@@ -103,6 +103,42 @@ def unbatch():
return _apply_fn
+def filter_irregular_batches(batch_size):
+ """Transformation that filters out batches that are not of size batch_size."""
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ tensor_batch_size = ops.convert_to_tensor(
+ batch_size, dtype=dtypes.int64, name="batch_size")
+
+ flattened = _RestructuredDataset(dataset,
+ tuple(nest.flatten(dataset.output_types)))
+
+ def _predicate(*xs):
+ """Return `True` if this element is a full batch."""
+ # Extract the dynamic batch size from the first component of the flattened
+ # batched element.
+ first_component = xs[0]
+ first_component_batch_size = array_ops.shape(
+ first_component, out_type=dtypes.int64)[0]
+
+ return math_ops.equal(first_component_batch_size, tensor_batch_size)
+
+ filtered = flattened.filter(_predicate)
+
+ maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size)
+
+ def _set_first_dimension(shape):
+ return shape.merge_with(
+ tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:]))
+
+ known_shapes = nest.map_structure(_set_first_dimension,
+ dataset.output_shapes)
+ return _RestructuredDataset(filtered, dataset.output_types, known_shapes)
+
+ return _apply_fn
+
+
def batch_and_drop_remainder(batch_size):
"""A batching transformation that omits the final small batch (if present).
@@ -135,34 +171,43 @@ def batch_and_drop_remainder(batch_size):
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
- tensor_batch_size = ops.convert_to_tensor(
- batch_size, dtype=dtypes.int64, name="batch_size")
+ batched = dataset.batch(batch_size)
+ return filter_irregular_batches(batch_size)(batched)
- batched = dataset.batch(tensor_batch_size)
- flattened = _RestructuredDataset(batched,
- tuple(nest.flatten(batched.output_types)))
+ return _apply_fn
- def _predicate(*xs):
- """Return `True` if this element is a full batch."""
- # Extract the dynamic batch size from the first component of the flattened
- # batched element.
- first_component = xs[0]
- first_component_batch_size = array_ops.shape(
- first_component, out_type=dtypes.int64)[0]
- return math_ops.equal(first_component_batch_size, tensor_batch_size)
+def padded_batch_and_drop_remainder(batch_size,
+ padded_shapes,
+ padding_values=None):
+ """A batching and padding transformation that omits the final small batch.
- filtered = flattened.filter(_predicate)
+ Like @{tf.data.Dataset.padded_batch}, this transformation combines
+ consecutive elements of this dataset into batches. However, if the batch
+ size does not evenly divide the input dataset size, this transformation will
+ drop the final smaller element.
- maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size)
+ See `@{tf.contrib.data.batch_and_drop_remainder}` for more details.
- def _set_first_dimension(shape):
- return shape.merge_with(
- tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:]))
+ Args:
+ batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
+ consecutive elements of this dataset to combine in a single batch.
+ padded_shapes: A nested structure of `tf.TensorShape` or
+ `tf.int64` vector tensor-like objects. See
+ @{tf.data.Dataset.padded_batch} for details.
+ padding_values: (Optional.) A nested structure of scalar-shaped
+ `tf.Tensor`. See @{tf.data.Dataset.padded_batch} for details.
- known_shapes = nest.map_structure(_set_first_dimension,
- batched.output_shapes)
- return _RestructuredDataset(filtered, batched.output_types, known_shapes)
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ @{tf.data.Dataset.apply}
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ batched = dataset.padded_batch(
+ batch_size, padded_shapes=padded_shapes, padding_values=padding_values)
+ return filter_irregular_batches(batch_size)(batched)
return _apply_fn
diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py
index c4c4426809..45d6dbe743 100644
--- a/tensorflow/contrib/data/python/ops/dataset_ops.py
+++ b/tensorflow/contrib/data/python/ops/dataset_ops.py
@@ -20,21 +20,15 @@ from __future__ import print_function
from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import enumerate_ops
from tensorflow.contrib.data.python.ops import error_ops
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.contrib.data.python.ops import grouping
-from tensorflow.contrib.util import loader
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
+from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_io_ops
-from tensorflow.python.platform import resource_loader
from tensorflow.python.util import deprecation
-_dataset_ops = loader.load_op_library(
- resource_loader.get_path_to_datafile("../../_dataset_ops.so"))
-
-
class Dataset(dataset_ops.Dataset):
"""Represents a potentially large set of elements.
diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py
index 51a2791072..238bb52b02 100644
--- a/tensorflow/contrib/data/python/ops/error_ops.py
+++ b/tensorflow/contrib/data/python/ops/error_ops.py
@@ -17,9 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
+from tensorflow.python.ops import gen_dataset_ops
def ignore_errors():
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 1c7c94b3c8..6df7b22fb6 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -17,12 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
def group_by_window(key_func,
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index ce23e95697..74a919c1ff 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -17,12 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util import deprecation
diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py
index 32d2f42c93..d736029fb0 100644
--- a/tensorflow/contrib/data/python/ops/iterator_ops.py
+++ b/tensorflow/contrib/data/python/ops/iterator_ops.py
@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.training import saver
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index f22298b757..2e1c3153ca 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -18,7 +18,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.data.util import nest
@@ -26,6 +25,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import gfile
from tensorflow.python.util import deprecation
diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py
index 87bbbb7d19..5acaed48a3 100644
--- a/tensorflow/contrib/data/python/ops/scan_ops.py
+++ b/tensorflow/contrib/data/python/ops/scan_ops.py
@@ -19,11 +19,11 @@ from __future__ import print_function
import collections
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
class _ScanDataset(dataset_ops.Dataset):
diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md
index ae4b07799f..dcc370cd00 100644
--- a/tensorflow/contrib/eager/README.md
+++ b/tensorflow/contrib/eager/README.md
@@ -1,4 +1,4 @@
-# TensorFlow Eager Execution
+# Eager Execution
> *WARNING*: This is a preview/pre-alpha version. The API and performance
> characteristics are subject to change.
diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py
index c6e628b074..1a5c6e8aec 100644
--- a/tensorflow/contrib/eager/python/network.py
+++ b/tensorflow/contrib/eager/python/network.py
@@ -244,6 +244,12 @@ class Network(base.Layer):
self._owned_layers = {}
# The scope to use if we end up without a parent.
self._default_parent_variable_scope = variable_scope.get_variable_scope()
+ # Hold on to the variable scope counts from init to check whether a scope
+ # with the name we want was ever created in our parent scope. Without this
+ # check we might have name collisions if the parent scope on init gets
+ # closed before build is called.
+ self._variable_scope_counts_on_init = (
+ variable_scope._get_default_variable_store().variable_scopes_count)
self._custom_getter, self._deferred_restorations = (
_make_custom_getter_for_deferred_restorations())
@@ -261,18 +267,29 @@ class Network(base.Layer):
def _finalize_name(self, parent_network):
if not self._name:
- if not parent_network:
- name_uid_map = base._get_default_graph_uid_map()
- else:
- name_uid_map = parent_network._sub_layer_name_uids
# Were were not passed a name explicitly (or it was blank), so this is an
# anonymous Network. We make up a unique name.
if parent_network:
avoid_names = parent_network._owned_layers
+ name_uid_map = parent_network._sub_layer_name_uids
else:
- avoid_names = None
+ name_uid_map = base._get_default_graph_uid_map()
+ # Figure out which names we have to avoid based on which variable scope
+ # we're nested in.
+ strip_name = self._default_parent_variable_scope.name
+ if strip_name:
+ strip_name += "/"
+ def _strip_on_init_scope(name):
+ if name.startswith(strip_name):
+ return name[len(strip_name):]
+ else:
+ return None
+ avoid_names = set(
+ _strip_on_init_scope(name)
+ for name in self._variable_scope_counts_on_init.keys() if name)
self._name, self._base_name = self._make_unique_name(
- name_uid_map=name_uid_map, avoid_names=avoid_names)
+ name_uid_map=name_uid_map, avoid_names=avoid_names,
+ namespace=self._default_parent_variable_scope.name)
if self._first_parent is None or (self._first_parent # False = no parent
and self._first_parent() is None):
# Save a pointer to the parent Network so that we can later check that the
@@ -302,7 +319,13 @@ class Network(base.Layer):
parent_scope = first_parent._scope
else:
parent_scope = self._default_parent_variable_scope
- with variable_scope.variable_scope(parent_scope):
+ with variable_scope.variable_scope(parent_scope) as parent_vs:
+ expected_scope_name = parent_vs.name + "/" + self._name
+ if expected_scope_name in self._variable_scope_counts_on_init:
+ raise ValueError(
+ ("A Network named '%s' already exists (or a variable_scope was "
+ "created with this name). Names must be unique.") % (
+ self._name,))
# Make sure variables with this prefix will be unique.
with variable_scope.variable_scope(
None, use_resource=True, default_name=self._name) as scope:
@@ -319,25 +342,22 @@ class Network(base.Layer):
"created with this name). Names must be unique.") % (
self._name,))
if (first_parent
- and scope_prefix[:-1] != first_parent._scope.name):
+ and scope_prefix[:-1] != first_parent.scope_name):
raise ValueError(
("Network variable names must match a nesting of sub-Network "
"names. Expected prefix '%s' from parent network, but got "
"'%s' when attempting to create a variable_scope for Network "
"'%s'. Likely an explicit variable_scope was inserted into "
"the nesting.") % (
- first_parent._scope.name,
+ first_parent.scope_name,
scope_prefix[:-1],
self._name))
elif not first_parent and scope_prefix:
# For the case when this Network is not nested inside any other
- # Network, but is in a variable_scope. This is an error for now.
- raise ValueError(
- "Creating Networks inside named variable_scopes is currently "
- "not supported (to ensure that variable names match the names "
- "of Networks in which they were first created). To set "
- "options, try `with tf.variable_scope(''):`. If this "
- "limitation bothers you, please file a feature request.")
+ # Network, but is in a variable_scope. This Network's name takes on
+ # the full variable scope prefix.
+ self._name = scope_name
+
for non_network_sublayer in self._non_network_sublayers:
self._set_scope_for_nonnetwork_sublayer(non_network_sublayer)
@@ -355,8 +375,7 @@ class Network(base.Layer):
raise ValueError(
("The parent of a Layer added to Network %s was garbage collected "
"before the Layer was built. If this limitation bothers you "
- "please, comment on "
- "https://github.com/tensorflow/tensorflow/issues/14164.") %
+ "please file a feature request.") %
(self.name,))
with variable_scope.variable_scope(parent_scope):
# Horrid hack to make Layer variable names which are direct
@@ -420,7 +439,9 @@ class Network(base.Layer):
# name, and we should respect it (subject to error checking).
layer._name, layer._base_name = layer._make_unique_name(
name_uid_map=self._sub_layer_name_uids,
- avoid_names=self._owned_layers)
+ avoid_names=self._owned_layers
+ # No namespace required, since we've specified our own UID map.
+ )
layer._first_parent = weakref.ref(self)
self._non_network_sublayers.append(layer)
if (not layer.built
@@ -556,7 +577,7 @@ class Network(base.Layer):
if os.path.isdir(save_path):
# If we were passed a directory, default to naming based on the Network
# name.
- save_path = os.path.join(save_path, self.name)
+ save_path = os.path.join(save_path, self.name.replace("/", "_"))
user_map_func = map_func
if map_func is None:
map_func = _make_prefix_stripping_map_fn(self.scope_name)
@@ -750,7 +771,7 @@ class Network(base.Layer):
self._set_scope() # scope_name should be available to map_funcs
if os.path.isdir(save_path):
# If we don't have a name yet, set no parent.
- save_path = os.path.join(save_path, self.name)
+ save_path = os.path.join(save_path, self.name.replace("/", "_"))
user_map_func = map_func
if map_func is None:
map_func = _make_prefix_stripping_map_fn(self.scope_name)
diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py
index 14adbafe57..1127055c05 100644
--- a/tensorflow/contrib/eager/python/network_test.py
+++ b/tensorflow/contrib/eager/python/network_test.py
@@ -410,19 +410,103 @@ class NetworkTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testWrappingInVariableScope(self):
+ one = constant_op.constant([[1.]])
+ # Naming happens in the order of first build rather than the order of
+ # construction, but for clarity they're the same here and construction is
+ # annotated.
+ outside_net_before = MyNetwork() # name=my_network_1
+ outside_net_before(one)
+ captured_scope = variable_scope.get_variable_scope()
with variable_scope.variable_scope("outside_scope"):
- net = MyNetwork()
- one = constant_op.constant([[1.]])
- with self.assertRaisesRegexp(
- ValueError,
- ("Creating Networks inside named variable_scopes is currently not "
- "supported")):
- net(one)
- # Alternatively, we could re-name the Network to match the variable_scope:
- # self.assertEqual("outside_scope/my_network_1", net.name)
- # self.assertStartsWith(
- # expected_start="outside_scope/my_network_1/dense/",
- # actual=net.trainable_weights[0].name)
+ net1 = MyNetwork() # name=outside_scope/my_network_1
+ net1(one)
+ name_conflict1 = MyNetwork(name="name_conflict") # fine, unique so far
+ name_conflict2 = MyNetwork(name="name_conflict") # error on build
+ with variable_scope.variable_scope("inside_scope"):
+ # No issue here since the name is unique within its scope.
+ name_conflict3 = MyNetwork(name="name_conflict")
+ net2 = MyNetwork() # name=outside_scope/my_network_3 to avoid the
+ # variable_scope my_network_2 below.
+ vs_name_conflict = MyNetwork(name="vs_name_conflict") # conflict below
+ with variable_scope.variable_scope("intervening_scope"):
+ with variable_scope.variable_scope(captured_scope):
+ with variable_scope.variable_scope("outside_scope"):
+ name_conflict4 = MyNetwork(name="name_conflict") # error on build
+ with variable_scope.variable_scope("my_network_2"):
+ pass
+ with variable_scope.variable_scope("vs_name_conflict"):
+ pass
+ net3 = MyNetwork() # name=outside_scope/my_network_4
+ name_conflict1(one)
+ with self.assertRaisesRegexp(
+ ValueError, "named 'name_conflict' already exists"):
+ name_conflict2(one)
+ name_conflict3(one)
+ net2(one)
+ with self.assertRaisesRegexp(
+ ValueError, "or a variable_scope was created with this name"):
+ vs_name_conflict(one)
+ with self.assertRaisesRegexp(
+ ValueError, "named 'name_conflict' already exists"):
+ name_conflict4(one)
+ self.assertEqual("outside_scope/name_conflict",
+ name_conflict1.name)
+ self.assertStartsWith(
+ expected_start="outside_scope/name_conflict/dense_1/",
+ actual=name_conflict1.variables[0].name)
+ self.assertEqual("outside_scope/inside_scope/name_conflict",
+ name_conflict3.name)
+ self.assertStartsWith(
+ expected_start="outside_scope/inside_scope/name_conflict/dense_1/",
+ actual=name_conflict3.variables[0].name)
+ self.assertEqual("outside_scope/my_network_1", net1.name)
+ self.assertStartsWith(
+ expected_start="outside_scope/my_network_1/dense_1/",
+ actual=net1.trainable_weights[0].name)
+ self.assertEqual("outside_scope/my_network_3", net2.name)
+ self.assertStartsWith(
+ expected_start="outside_scope/my_network_3/dense_1/",
+ actual=net2.trainable_weights[0].name)
+ net3(one)
+ self.assertEqual("outside_scope/my_network_4", net3.name)
+ self.assertStartsWith(
+ expected_start="outside_scope/my_network_4/dense_1/",
+ actual=net3.trainable_weights[0].name)
+ outside_net_after = MyNetwork()
+ outside_net_after(one)
+ self.assertEqual("my_network_1", outside_net_before.name)
+ self.assertStartsWith(
+ expected_start="my_network_1/dense_1/",
+ actual=outside_net_before.trainable_weights[0].name)
+ self.assertEqual("my_network_2", outside_net_after.name)
+ self.assertStartsWith(
+ expected_start="my_network_2/dense_1/",
+ actual=outside_net_after.trainable_weights[0].name)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testVariableScopeStripping(self):
+ with variable_scope.variable_scope("scope1"):
+ with variable_scope.variable_scope("scope2"):
+ net = MyNetwork()
+ net(constant_op.constant([[2.0]]))
+ self.evaluate(net.variables[0].assign([[42.]]))
+ self.assertEqual(net.name, "scope1/scope2/my_network_1")
+ self.assertStartsWith(
+ expected_start="scope1/scope2/my_network_1/dense_1/",
+ actual=net.trainable_weights[0].name)
+ save_path = net.save(self.get_temp_dir())
+ self.assertIn("scope1_scope2_my_network_1", save_path)
+ restore_net = MyNetwork()
+ # Delayed restoration
+ restore_net.restore(save_path)
+ restore_net(constant_op.constant([[1.0]]))
+ self.assertAllEqual([[42.]],
+ self.evaluate(restore_net.variables[0]))
+ self.evaluate(restore_net.variables[0].assign([[-1.]]))
+ # Immediate restoration
+ restore_net.restore(save_path)
+ self.assertAllEqual([[42.]],
+ self.evaluate(restore_net.variables[0]))
@test_util.run_in_graph_and_eager_modes()
def testLayerNamesRespected(self):
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 6eb2cfdaca..bc67ef8354 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -204,10 +204,13 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
+ "//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:summary",
"//tensorflow/python/estimator:head",
+ "//tensorflow/python/estimator:metric_keys",
"//tensorflow/python/estimator:model_fn",
"//tensorflow/python/saved_model:signature_constants",
"@six_archive//:six",
@@ -229,7 +232,7 @@ py_test(
"//tensorflow/python:string_ops",
"//tensorflow/python/estimator:metric_keys",
"//tensorflow/python/estimator:model_fn",
- "//tensorflow/python/ops/losses",
+ "//tensorflow/python/estimator:prediction_keys",
"//tensorflow/python/saved_model:signature_constants",
"//third_party/py/numpy",
"@six_archive//:six",
diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py
index e344ee3c3e..a9311a20f1 100644
--- a/tensorflow/contrib/estimator/python/estimator/head.py
+++ b/tensorflow/contrib/estimator/python/estimator/head.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
@@ -48,7 +49,20 @@ def multi_class_head(n_classes,
Uses `sparse_softmax_cross_entropy` loss.
- This head expects to be fed integer labels specifying the class index.
+ The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`.
+ In many applications, the shape is `[batch_size, n_classes]`.
+
+ `labels` must be a dense `Tensor` with shape matching `logits`, namely
+ `[D0, D1, ... DN, 1]`. If `label_vocabulary` given, `labels` must be a string
+ `Tensor` with values from the vocabulary. If `label_vocabulary` is not given,
+ `labels` must be an integer `Tensor` with values specifying the class index.
+
+ If `weight_column` is specified, weights must be of shape
+ `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.
+
+ The loss is the weighted sum over the input dimensions. Namely, if the input
+ labels have shape `[batch_size, 1]`, the loss is the weighted sum over
+ `batch_size`.
Args:
n_classes: Number of classes, must be greater than 2 (for 2 classes, use
@@ -57,11 +71,11 @@ def multi_class_head(n_classes,
`tf.feature_column.numeric_column` defining feature column representing
weights. It is used to down weight or boost examples during training. It
will be multiplied by the loss of the example.
- label_vocabulary: A list of strings represents possible label values. If it
- is not given, that means labels are already encoded as integer within
- [0, n_classes). If given, labels must be string type and have any value in
- `label_vocabulary`. Also there will be errors if vocabulary is not
- provided and labels are string.
+ label_vocabulary: A list or tuple of strings representing possible label
+ values. If it is not given, that means labels are already encoded as an
+ integer within [0, n_classes). If given, labels must be of string type and
+ have any value in `label_vocabulary`. Note that errors will be raised if
+ `label_vocabulary` is not provided but labels are strings.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
@@ -84,7 +98,20 @@ def binary_classification_head(
This head uses `sigmoid_cross_entropy_with_logits` loss.
- This head expects to be fed float labels of shape `(batch_size, 1)`.
+ The head expects `logits` with shape `[D0, D1, ... DN, 1]`.
+ In many applications, the shape is `[batch_size, 1]`.
+
+ `labels` must be a dense `Tensor` with shape matching `logits`, namely
+ `[D0, D1, ... DN, 1]`. If `label_vocabulary` given, `labels` must be a string
+ `Tensor` with values from the vocabulary. If `label_vocabulary` is not given,
+ `labels` must be float `Tensor` with values in the interval `[0, 1]`.
+
+ If `weight_column` is specified, weights must be of shape
+ `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.
+
+ The loss is the weighted sum over the input dimensions. Namely, if the input
+ labels have shape `[batch_size, 1]`, the loss is the weighted sum over
+ `batch_size`.
Args:
weight_column: A string or a `_NumericColumn` created by
@@ -96,11 +123,11 @@ def binary_classification_head(
generated for each threshold value. This threshold is applied to the
logistic values to determine the binary classification (i.e., above the
threshold is `true`, below is `false`.
- label_vocabulary: A list of strings represents possible label values. If it
- is not given, that means labels are already encoded within [0, 1]. If
- given, labels must be string type and have any value in
- `label_vocabulary`. Also there will be errors if vocabulary is not
- provided and labels are string.
+ label_vocabulary: A list or tuple of strings representing possible label
+ values. If it is not given, labels must be float with values within
+ [0, 1]. If given, labels must be string type and have any value in
+ `label_vocabulary`. Note that errors will be raised if `label_vocabulary`
+ is not provided but labels are strings.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
@@ -120,9 +147,22 @@ def binary_classification_head(
def regression_head(weight_column=None,
label_dimension=1,
name=None):
- """Creates a `_Head` for regression using the mean squared loss.
+ """Creates a `_Head` for regression using the `mean_squared_error` loss.
+
+ The loss is the weighted sum over all input dimensions. Namely, if the input
+ labels have shape `[batch_size, label_dimension]`, the loss is the weighted
+ sum over both `batch_size` and `label_dimension`.
+
+ The head expects `logits` with shape `[D0, D1, ... DN, label_dimension]`.
+ In many applications, the shape is `[batch_size, label_dimension]`.
+
+ The `labels` shape must match `logits`, namely
+ `[D0, D1, ... DN, label_dimension]`. If `label_dimension=1`, shape
+ `[D0, D1, ... DN]` is also supported.
- Uses `mean_squared_error` loss.
+ If `weight_column` is specified, weights must be of shape
+ `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or
+ `[D0, D1, ... DN, label_dimension]`.
Args:
weight_column: A string or a `_NumericColumn` created by
@@ -156,15 +196,29 @@ def multi_label_head(n_classes,
or more associated labels, from a discrete set. This is distinct from
`multi_class_head` which has exactly one label per example.
- Uses `sigmoid_cross_entropy` loss averaged over classes. Expects labels as a
- multi-hot tensor of shape `[batch_size, n_classes]`, or as an integer
- `SparseTensor` of class indices.
+ Uses `sigmoid_cross_entropy` loss average over classes and weighted sum over
+ the batch. Namely, if the input logits have shape `[batch_size, n_classes]`,
+ the loss is the average over `n_classes` and the weighted sum over
+ `batch_size`.
+
+ The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. In many
+ applications, the shape is `[batch_size, label_n_classes]`.
+
+ Labels can be:
+ * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`
+ * An integer `SparseTensor` of class indices. The `dense_shape` must be
+ `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`.
+ * If `label_vocabulary` is given, a string `SparseTensor`. The `dense_shape`
+ must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary`.
+
+ If `weight_column` is specified, weights must be of shape
+ `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.
Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
`(labels, logits, features)` as arguments and returns unreduced loss with
- shape `[batch_size, 1]`. `loss_fn` must support indicator `labels` with shape
- `[batch_size, n_classes]`. Namely, the head applies `label_vocabulary` to the
- input labels before passing them to `loss_fn`.
+ shape `[D0, D1, ... DN, 1]`. `loss_fn` must support indicator `labels` with
+ shape `[D0, D1, ... DN, n_classes]`. Namely, the head applies
+ `label_vocabulary` to the input labels before passing them to `loss_fn`.
Args:
n_classes: Number of classes, must be greater than 1 (for 1 class, use
@@ -191,7 +245,7 @@ def multi_label_head(n_classes,
An instance of `_Head` for multi-label classification.
Raises:
- ValueError: if `n_classes` or `thresholds` is invalid.
+ ValueError: if `n_classes`, `thresholds`, or `loss_fn` is invalid.
"""
thresholds = tuple(thresholds) if thresholds else tuple()
if n_classes is None or n_classes < 2:
@@ -259,26 +313,36 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
indices=labels.indices,
values=label_ids_values,
dense_shape=labels.dense_shape)
+ return math_ops.to_int64(
+ sparse_ops.sparse_to_indicator(label_ids, self._n_classes))
else:
- label_ids = labels
- return math_ops.to_int64(
- sparse_ops.sparse_to_indicator(label_ids, self._n_classes))
- msg = ('labels shape must be [batch_size, {}]. '
- 'Given: ').format(self._n_classes)
- labels_shape = array_ops.shape(labels)
- check_rank_op = control_flow_ops.Assert(
- math_ops.equal(array_ops.rank(labels), 2),
- data=[msg, labels_shape])
- check_label_dim = control_flow_ops.Assert(
- math_ops.equal(labels_shape[-1], self._n_classes),
- data=[msg, labels_shape])
- with ops.control_dependencies([check_rank_op, check_label_dim]):
- return array_ops.identity(labels)
+ err_msg = (
+ r'labels must be an integer SparseTensor with values in '
+ r'[0, {})'.format(self._n_classes))
+ assert_int = check_ops.assert_integer(
+ labels.values, message=err_msg)
+ assert_less = check_ops.assert_less(
+ labels.values,
+ ops.convert_to_tensor(self._n_classes, dtype=labels.dtype),
+ message=err_msg)
+ assert_greater = check_ops.assert_non_negative(
+ labels.values, message=err_msg)
+ with ops.control_dependencies(
+ [assert_int, assert_less, assert_greater]):
+ return math_ops.to_int64(
+ sparse_ops.sparse_to_indicator(labels, self._n_classes))
+ err_msg = (
+ r'labels must be an integer indicator Tensor with values in [0, 1]')
+ return head_lib._assert_range(labels, 2, message=err_msg) # pylint:disable=protected-access,
def create_loss(self, features, mode, logits, labels):
"""See `Head`."""
del mode # Unused for this head.
+ logits = ops.convert_to_tensor(logits)
processed_labels = self._process_labels(labels)
+ processed_labels = head_lib._check_dense_labels_match_logits_and_reshape( # pylint:disable=protected-access
+ labels=processed_labels, logits=logits,
+ expected_labels_dimension=self.logits_dimension)
if self._loss_fn:
unweighted_loss = _call_loss_fn(
loss_fn=self._loss_fn, labels=processed_labels, logits=logits,
@@ -290,7 +354,8 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
# Averages loss over classes.
unweighted_loss = math_ops.reduce_mean(
unweighted_loss, axis=-1, keep_dims=True)
- weights = head_lib._weights(features, self._weight_column) # pylint:disable=protected-access,
+ weights = head_lib._get_weights_and_check_match_logits( # pylint:disable=protected-access,
+ features=features, weight_column=self._weight_column, logits=logits)
weighted_sum_loss = losses.compute_weighted_loss(
unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
# _weights() can return 1.
@@ -305,7 +370,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
self, features, mode, logits, labels=None, train_op_fn=None):
"""See `Head`."""
with ops.name_scope(self._name, 'head'):
- logits = head_lib._check_logits(logits, self.logits_dimension) # pylint:disable=protected-access
+ logits = head_lib._check_logits_final_dim(logits, self.logits_dimension) # pylint:disable=protected-access
# Predict.
pred_keys = prediction_keys.PredictionKeys
@@ -335,6 +400,8 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
# Eval.
if mode == model_fn.ModeKeys.EVAL:
+ weights = head_lib._get_weights_and_check_match_logits( # pylint:disable=protected-access,
+ features=features, weight_column=self._weight_column, logits=logits)
return model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
predictions=predictions,
@@ -342,7 +409,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
eval_metric_ops=self._eval_metric_ops(
labels=processed_labels,
probabilities=probabilities,
- weights=head_lib._weights(features, self._weight_column), # pylint:disable=protected-access,
+ weights=weights,
weighted_sum_loss=weighted_sum_loss,
example_weight_sum=example_weight_sum))
diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py
index fd8c53f6a9..d1cf909004 100644
--- a/tensorflow/contrib/estimator/python/estimator/head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/head_test.py
@@ -316,13 +316,14 @@ class MultiLabelHead(test.TestCase):
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
- r'labels shape must be \[batch_size, 2\]\. Given: \] \[2 1\]'):
+ r'\[expected_labels_shape: \] \[2 2\] \[labels_shape: \] \[2 1\]'):
actual_weighted_sum_loss.eval({
labels_placeholder: np.array([[1], [1]], dtype=np.int64)
})
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
- r'labels shape must be \[batch_size, 2\]\. Given: \] \[2\]'):
+ r'labels shape must be \[D0, D1, ... DN, 2\]\..*'
+ r'\[Received shape: \] \[2\]'):
actual_weighted_sum_loss.eval({
labels_placeholder: np.array([1, 1], dtype=np.int64)
})
@@ -387,9 +388,11 @@ class MultiLabelHead(test.TestCase):
logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),
labels=None)
- def _test_eval(self, head, logits, labels, expected_loss, expected_metrics):
+ def _test_eval(
+ self, head, logits, labels, expected_loss, expected_metrics,
+ features=None):
spec = head.create_estimator_spec(
- features={'x': np.array(((42,),), dtype=np.int32)},
+ features=features or {},
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)
@@ -655,6 +658,54 @@ class MultiLabelHead(test.TestCase):
labels=None,
train_op_fn=_no_op_train_fn)
+ def test_train_invalid_indicator_labels(self):
+ head = head_lib.multi_label_head(n_classes=2)
+ logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
+ # The value 2 is outside the allowed range.
+ labels = np.array([[2, 0], [1, 1]], dtype=np.int64)
+ def _train_op_fn(loss):
+ del loss
+ return control_flow_ops.no_op()
+
+ spec = head.create_estimator_spec(
+ features={},
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels,
+ train_op_fn=_train_op_fn)
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r'labels must be an integer indicator Tensor with values in '
+ r'\[0, 1\]'):
+ sess.run(spec.loss)
+
+ def test_train_invalid_sparse_labels(self):
+ head = head_lib.multi_label_head(n_classes=2)
+ logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
+ # The value 2 is outside the allowed range.
+ labels = sparse_tensor.SparseTensor(
+ values=[2, 0, 1],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ def _train_op_fn(loss):
+ del loss
+ return control_flow_ops.no_op()
+
+ spec = head.create_estimator_spec(
+ features={},
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels,
+ train_op_fn=_train_op_fn)
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r'labels must be an integer SparseTensor with values in \[0, 2\)'):
+ sess.run(spec.loss)
+
def _test_train(self, head, logits, labels, expected_loss):
expected_train_result = 'my_train_op'
def _train_op_fn(loss):
@@ -791,6 +842,153 @@ class MultiLabelHead(test.TestCase):
metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 3,
}, summary_str, tol)
+ def test_multi_dim_weighted_train_create_loss(self):
+ """Logits and labels of shape [2, 2, 3], weights [2, 2]."""
+ head = head_lib.multi_label_head(n_classes=3, weight_column='weights')
+
+ logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],
+ [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32)
+ labels = np.array([[[1, 0, 0], [1, 0, 0]],
+ [[0, 1, 1], [0, 1, 1]]], dtype=np.int64)
+ weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)
+ # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3
+ # = [[20/3, 10/3], [4, 8]]
+ # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667
+ expected_weighted_sum_loss = 39.6667
+ expected_example_weight_sum = np.sum(weights)
+ actual_weighted_sum_loss, actual_example_weight_sum, _ = head.create_loss(
+ features={'weights': weights},
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels)
+ atol = 1.e-3
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ self.assertAllClose(
+ expected_weighted_sum_loss, actual_weighted_sum_loss.eval(),
+ atol=atol)
+ self.assertAllClose(
+ expected_example_weight_sum, actual_example_weight_sum.eval(),
+ atol=atol)
+
+ def test_multi_dim_weighted_train(self):
+ """Logits and labels of shape [2, 2, 3], weights [2, 2]."""
+ head = head_lib.multi_label_head(n_classes=3, weight_column='weights')
+
+ logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],
+ [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32)
+ labels = np.array([[[1, 0, 0], [1, 0, 0]],
+ [[0, 1, 1], [0, 1, 1]]], dtype=np.int64)
+ weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)
+ # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3
+ # = [[20/3, 10/3], [4, 8]]
+ # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667
+ expected_loss = 39.6667
+ expected_train_result = 'my_train_op'
+ def _train_op_fn(loss):
+ return string_ops.string_join(
+ [constant_op.constant(expected_train_result),
+ string_ops.as_string(loss, precision=3)])
+
+ spec = head.create_estimator_spec(
+ features={'weights': weights},
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels,
+ train_op_fn=_train_op_fn)
+
+ atol = 1.e-3
+ with self.test_session() as sess:
+ _initialize_variables(self, monitored_session.Scaffold())
+ loss, train_result = sess.run((spec.loss, spec.train_op))
+ self.assertAllClose(expected_loss, loss, atol=atol)
+ self.assertEqual(
+ six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
+ train_result)
+
+ def test_multi_dim_weights_wrong_inner_dim(self):
+ """Logits and labels of shape [2, 2, 3], weights [2, 1]."""
+ head = head_lib.multi_label_head(n_classes=3, weight_column='weights')
+
+ logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],
+ [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32)
+ labels = np.array([[[1, 0, 0], [1, 0, 0]],
+ [[0, 1, 1], [0, 1, 1]]], dtype=np.int64)
+ weights = np.array([[1.], [2.]], dtype=np.float32)
+ def _train_op_fn(loss):
+ del loss
+ return control_flow_ops.no_op()
+
+ spec = head.create_estimator_spec(
+ features={'weights': weights},
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels,
+ train_op_fn=_train_op_fn)
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r'\[logits_shape: \] \[2 2 3\] \[weights_shape: \] \[2 1\]'):
+ spec.loss.eval()
+
+ def test_multi_dim_weights_wrong_outer_dim(self):
+ """Logits and labels of shape [2, 2, 3], weights [2, 2, 3]."""
+ head = head_lib.multi_label_head(n_classes=3, weight_column='weights')
+
+ logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],
+ [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32)
+ labels = np.array([[[1, 0, 0], [1, 0, 0]],
+ [[0, 1, 1], [0, 1, 1]]], dtype=np.int64)
+ weights = np.array([[[1., 1., 1.], [1.5, 1.5, 1.5]],
+ [[2., 2., 2.], [2.5, 2.5, 2.5]]], dtype=np.float32)
+ weights_placeholder = array_ops.placeholder(dtype=dtypes.float32)
+ def _train_op_fn(loss):
+ del loss
+ return control_flow_ops.no_op()
+
+ spec = head.create_estimator_spec(
+ features={'weights': weights_placeholder},
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels,
+ train_op_fn=_train_op_fn)
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r'\[logits_shape: \] \[2 2 3\] \[weights_shape: \] \[2 2 3\]'):
+ spec.loss.eval({weights_placeholder: weights})
+
+ def test_multi_dim_weighted_eval(self):
+ """Logits and labels of shape [2, 2, 3], weights [2, 2]."""
+ head = head_lib.multi_label_head(n_classes=3, weight_column='weights')
+
+ logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],
+ [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32)
+ labels = np.array([[[1, 0, 0], [1, 0, 0]],
+ [[0, 1, 1], [0, 1, 1]]], dtype=np.int64)
+ weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)
+ # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3
+ # = [[20/3, 10/3], [4, 8]]
+ # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667
+ expected_loss = 39.6667
+ keys = metric_keys.MetricKeys
+ expected_metrics = {
+ keys.LOSS_MEAN: expected_loss / np.sum(weights),
+ # auc and auc_pr cannot be reliably calculated for only 4 samples, but
+ # this assert tests that the algorithm remains consistent.
+ keys.AUC: 0.4977,
+ keys.AUC_PR: 0.6645,
+ }
+ self._test_eval(
+ head=head,
+ features={'weights': weights},
+ logits=logits,
+ labels=labels,
+ expected_loss=expected_loss,
+ expected_metrics=expected_metrics)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py
index 69dbfcee62..73bae5acf9 100644
--- a/tensorflow/contrib/estimator/python/estimator/multi_head.py
+++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py
@@ -22,10 +22,13 @@ import six
from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import head as head_lib
+from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.summary import summary
_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
@@ -72,6 +75,23 @@ def multi_head(heads, head_weights=None):
estimator.train(input_fn=input_fn, steps=100)
```
+ Also supports `logits` as a `Tensor` of shape
+ `[D0, D1, ... DN, logits_dimension]`. It will split the `Tensor` along the
+ last dimension and distribute it appropriately among the heads. E.g.:
+
+ ```python
+ def model_fn(features, labels, mode):
+ # Create simple heads and specify head name.
+ head1 = multi_class_head(n_classes=3, name='head1')
+ head2 = binary_classification_head(name='head2')
+ # Create multi-head from two simple heads.
+ head = multi_head([head1, head2])
+ # Create logits for the multihead.
+ logits = logit_fn(logits_dimension=head.logits_dimension)
+ # Return the merged EstimatorSpec
+ return head.create_estimator_spec(..., logits=logits, ...)
+ ```
+
Args:
heads: List or tuple of `_Head` instances. All heads must have `name`
specified. The first head in the list is the default used at serving time.
@@ -161,18 +181,17 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access
def create_loss(self, features, mode, logits, labels):
"""See `Head`."""
- # TODO(roumposg): Add support for logits as single Tensor (with
- # _split_logits utility).
- if not isinstance(logits, dict):
- raise ValueError('logits must be a dict. Single Tensor support coming '
- 'soon.')
+ if isinstance(logits, dict):
+ logits_dict = logits
+ else:
+ logits_dict = self._split_logits(logits)
weighted_sum_losses = []
example_weight_sums = []
labels_by_head = {}
for head in self._heads:
(weighted_sum_loss,
example_weight_sum, processed_labels) = head.create_loss(
- features, mode, logits[head.name], labels[head.name])
+ features, mode, logits_dict[head.name], labels[head.name])
weighted_sum_losses.append(weighted_sum_loss)
example_weight_sums.append(example_weight_sum)
labels_by_head[head.name] = processed_labels
@@ -205,10 +224,10 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access
def create_estimator_spec(
self, features, mode, logits, labels=None, train_op_fn=None):
"""See `_Head`."""
- # TODO(roumposg): Add support for logits as single Tensor (with
- # _split_logits utility).
- if not isinstance(logits, dict):
- raise ValueError('logits must be a dict. Given: {}'.format(logits))
+ if isinstance(logits, dict):
+ logits_dict = logits
+ else:
+ logits_dict = self._split_logits(logits)
if labels and not isinstance(labels, dict):
raise ValueError('labels must be a dict. Given: {}'.format(labels))
@@ -219,22 +238,42 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access
head.create_estimator_spec(
features=features,
mode=mode,
- logits=logits[head_name],
+ logits=logits_dict[head_name],
labels=labels[head_name] if labels else None,
train_op_fn=_no_op_train_fn))
- # TODO(roumposg): Add LOSS and LOSS_MEAN summaries for the total head-
- # combined loss.
if mode == model_fn.ModeKeys.TRAIN:
if train_op_fn is None:
raise ValueError('train_op_fn can not be None in TRAIN mode.')
- return self._merge_train(all_estimator_spec, train_op_fn)
+ spec = self._merge_train(all_estimator_spec, train_op_fn)
+ with ops.name_scope(''):
+ summary.scalar(metric_keys.MetricKeys.LOSS, spec.loss)
+ return spec
if mode == model_fn.ModeKeys.PREDICT:
return self._merge_predict(all_estimator_spec)
if mode == model_fn.ModeKeys.EVAL:
return self._merge_eval(all_estimator_spec)
raise ValueError('mode={} unrecognized'.format(mode))
+ def _split_logits(self, logits):
+ """Splits logits along the last dimension and returns a dict."""
+ logits_dict = {}
+ with ops.name_scope(None, 'split_logits', values=[logits]):
+ logits = ops.convert_to_tensor(logits)
+ batch_shape = array_ops.shape(logits)[:-1]
+ zeros_like_batch_shape = array_ops.zeros_like(batch_shape)
+ minus_ones_like_batch_shape = -1 * array_ops.ones_like(batch_shape)
+ begin_idx = 0
+ for head in self._heads:
+ begin_tensor = array_ops.concat(
+ [zeros_like_batch_shape, [begin_idx]], axis=0)
+ size_tensor = array_ops.concat(
+ [minus_ones_like_batch_shape, [head.logits_dimension]], axis=0)
+ logits_dict[head.name] = array_ops.slice(
+ logits, begin=begin_tensor, size=size_tensor)
+ begin_idx += head.logits_dimension
+ return logits_dict
+
def _merge_train(self, all_estimator_spec, train_op_fn):
"""Merges list of `EstimatorSpec` for training.
diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
index 16177aebd5..8d51a298b2 100644
--- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
@@ -106,7 +106,8 @@ class MultiHeadTest(test.TestCase):
multi_head = multi_head_lib.multi_head([head1, head2])
self.assertEqual('head1_head2', multi_head.name)
- def test_predict_two_heads(self):
+ def test_predict_two_heads_logits_dict(self):
+ """Tests predict with logits as dict."""
head1 = head_lib.multi_label_head(n_classes=2, name='head1')
head2 = head_lib.multi_label_head(n_classes=3, name='head2')
multi_head = multi_head_lib.multi_head([head1, head2])
@@ -158,6 +159,111 @@ class MultiHeadTest(test.TestCase):
expected_probabilities['head2'],
sess.run(spec.export_outputs['head2'].scores))
+ def test_predict_two_heads_logits_tensor(self):
+ """Tests predict with logits as Tensor."""
+ head1 = head_lib.multi_label_head(n_classes=2, name='head1')
+ head2 = head_lib.multi_label_head(n_classes=3, name='head2')
+ multi_head = multi_head_lib.multi_head([head1, head2])
+
+ logits = np.array(
+ [[-1., 1., 2., -2., 2.], [-1.5, 1., -3., 2., -2.]], dtype=np.float32)
+ expected_logits1 = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32)
+ expected_logits2 = np.array([[2., -2., 2.], [-3., 2., -2.]],
+ dtype=np.float32)
+ expected_probabilities = {
+ 'head1': _sigmoid(expected_logits1),
+ 'head2': _sigmoid(expected_logits2),
+ }
+
+ spec = multi_head.create_estimator_spec(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.PREDICT,
+ logits=logits)
+
+ self.assertItemsEqual(
+ (_DEFAULT_SERVING_KEY, 'head1', 'classification/head1', 'predict/head1',
+ 'head2', 'classification/head2', 'predict/head2'),
+ spec.export_outputs.keys())
+
+ # Assert predictions and export_outputs.
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ self.assertIsNone(spec.scaffold.summary_op)
+ predictions = sess.run(spec.predictions)
+ self.assertAllClose(
+ expected_logits1,
+ predictions[('head1', prediction_keys.PredictionKeys.LOGITS)])
+ self.assertAllClose(
+ expected_logits2,
+ predictions[('head2', prediction_keys.PredictionKeys.LOGITS)])
+ self.assertAllClose(
+ expected_probabilities['head1'],
+ predictions[('head1', prediction_keys.PredictionKeys.PROBABILITIES)])
+ self.assertAllClose(
+ expected_probabilities['head2'],
+ predictions[('head2', prediction_keys.PredictionKeys.PROBABILITIES)])
+
+ self.assertAllClose(
+ expected_probabilities['head1'],
+ sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].scores))
+ self.assertAllClose(
+ expected_probabilities['head1'],
+ sess.run(spec.export_outputs['head1'].scores))
+ self.assertAllClose(
+ expected_probabilities['head2'],
+ sess.run(spec.export_outputs['head2'].scores))
+
+ def test_predict_two_heads_logits_tensor_multi_dim(self):
+ """Tests predict with multi-dimensional logits of shape [2, 2, 5]."""
+ head1 = head_lib.regression_head(label_dimension=2, name='head1')
+ head2 = head_lib.regression_head(label_dimension=3, name='head2')
+ multi_head = multi_head_lib.multi_head([head1, head2])
+
+ logits = np.array(
+ [[[-1., 1., 2., -2., 2.], [-1., 1., 2., -2., 2.]],
+ [[-1.5, 1., -3., 2., -2.], [-1.5, 1., -3., 2., -2.]]],
+ dtype=np.float32)
+ expected_logits1 = np.array(
+ [[[-1., 1.], [-1., 1.]],
+ [[-1.5, 1.], [-1.5, 1.]]],
+ dtype=np.float32)
+ expected_logits2 = np.array(
+ [[[2., -2., 2.], [2., -2., 2.]],
+ [[-3., 2., -2.], [-3., 2., -2.]]],
+ dtype=np.float32)
+
+ spec = multi_head.create_estimator_spec(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.PREDICT,
+ logits=logits)
+
+ self.assertItemsEqual(
+ (_DEFAULT_SERVING_KEY, 'head1', 'regression/head1', 'predict/head1',
+ 'head2', 'regression/head2', 'predict/head2'),
+ spec.export_outputs.keys())
+
+ # Assert predictions and export_outputs.
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ self.assertIsNone(spec.scaffold.summary_op)
+ predictions = sess.run(spec.predictions)
+ self.assertAllClose(
+ expected_logits1,
+ predictions[('head1', prediction_keys.PredictionKeys.PREDICTIONS)])
+ self.assertAllClose(
+ expected_logits2,
+ predictions[('head2', prediction_keys.PredictionKeys.PREDICTIONS)])
+
+ self.assertAllClose(
+ expected_logits1,
+ sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].value))
+ self.assertAllClose(
+ expected_logits1,
+ sess.run(spec.export_outputs['head1'].value))
+ self.assertAllClose(
+ expected_logits2,
+ sess.run(spec.export_outputs['head2'].value))
+
def test_eval_two_heads_with_weights(self):
head1 = head_lib.multi_label_head(n_classes=2, name='head1')
head2 = head_lib.multi_label_head(n_classes=3, name='head2')
@@ -284,6 +390,84 @@ class MultiHeadTest(test.TestCase):
# example_weight_sum = 1 * (1 + 2) + 2 * (2 + 3) = 13
self.assertAllClose(13., example_weight_sum.eval(), rtol=tol, atol=tol)
+ def test_train_create_loss_logits_tensor(self):
+ """Tests create_loss with logits Tensor."""
+ weights1 = np.array([[1.], [2.]], dtype=np.float32)
+ weights2 = np.array([[2.], [3.]])
+ head1 = head_lib.multi_label_head(n_classes=2, name='head1',
+ weight_column='weights1')
+ head2 = head_lib.multi_label_head(n_classes=3, name='head2',
+ weight_column='weights2')
+ multi_head = multi_head_lib.multi_head(
+ [head1, head2], head_weights=[1., 2.])
+
+ logits = np.array([[-10., 10., 20., -20., 20.],
+ [-15., 10., -30., 20., -20.]], dtype=np.float32)
+ labels = {
+ 'head1': np.array([[1, 0], [1, 1]], dtype=np.int64),
+ 'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64),
+ }
+ weighted_sum_loss, example_weight_sum, _ = multi_head.create_loss(
+ features={
+ 'x': np.array(((42,),), dtype=np.int32),
+ 'weights1': weights1,
+ 'weights2': weights2
+ },
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels)
+ tol = 1e-3
+ with self.test_session():
+ # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]]
+ # = [10, 7.5]
+ # weighted_sum_loss = 1 * 10 + 2 * 7.5 = 25
+ # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]]
+ # = [20, 10]
+ # weighted_sum_loss = 2 * 20 + 3 * 10 = 70
+ # head-weighted merge = 1 * 25 + 2 * 70 = 165
+ self.assertAllClose(165, weighted_sum_loss.eval(), rtol=tol, atol=tol)
+ # example_weight_sum = 1 * (1 + 2) + 2 * (2 + 3) = 13
+ self.assertAllClose(13., example_weight_sum.eval(), rtol=tol, atol=tol)
+
+ def test_train_create_loss_logits_tensor_multi_dim(self):
+ """Tests create_loss with multi-dimensional logits of shape [2, 2, 5]."""
+ head1 = head_lib.regression_head(label_dimension=2, name='head1')
+ head2 = head_lib.regression_head(label_dimension=3, name='head2')
+ multi_head = multi_head_lib.multi_head([head1, head2])
+
+ logits = np.array(
+ [[[-1., 1., 2., -2., 2.], [-1., 1., 2., -2., 2.]],
+ [[-1.5, 1.5, -2., 2., -2.], [-1.5, 1.5, -2., 2., -2.]]],
+ dtype=np.float32)
+ labels = {
+ 'head1': np.array([[[1., 0.], [1., 0.]],
+ [[1.5, 1.5], [1.5, 1.5]]], dtype=np.float32),
+ 'head2': np.array([[[0., 1., 0.], [0., 1., 0.]],
+ [[2., 2., 0.], [2., 2., 0.]]], dtype=np.float32),
+ }
+ # Loss for the first head:
+ # loss1 = (1+1)^2 + (0-1)^2 + (1+1)^2 + (0-1)^2 +
+ # (1.5+1.5)^2 + (1.5-1.5)^2 + (1.5+1.5)^2 + (1.5-1.5)^2
+ # = 28
+ # Loss for the second head:
+ # loss2 = (0-2)^2 + (1+2)^2 + (0-2)^2 + (0-2)^2 + (1+2)^2 + (0-2)^2 +
+ # (2+2)^2 + (2-2)^2 + (0+2)^2 + (2+2)^2 + (2-2)^2 + (0+2)^2
+ # = 74
+ expected_weighted_sum_loss = 28. + 74.
+
+ weighted_sum_loss, example_weight_sum, _ = multi_head.create_loss(
+ features={},
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels)
+ tol = 1e-3
+ with self.test_session():
+ self.assertAllClose(
+ expected_weighted_sum_loss, weighted_sum_loss.eval(),
+ rtol=tol, atol=tol)
+ self.assertAllClose(
+ 2. * 2. * 5., example_weight_sum.eval(), rtol=tol, atol=tol)
+
def test_train_one_head(self):
head1 = head_lib.multi_label_head(n_classes=2, name='head1')
multi_head = multi_head_lib.multi_head([head1])
@@ -327,6 +511,7 @@ class MultiHeadTest(test.TestCase):
six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
train_result)
_assert_simple_summaries(self, {
+ metric_keys.MetricKeys.LOSS: expected_loss,
metric_keys.MetricKeys.LOSS + '/head1': expected_loss,
# Average loss over examples.
metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss / 2,
@@ -387,6 +572,7 @@ class MultiHeadTest(test.TestCase):
six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
train_result)
_assert_simple_summaries(self, {
+ metric_keys.MetricKeys.LOSS: expected_loss,
metric_keys.MetricKeys.LOSS + '/head1': expected_loss_head1,
metric_keys.MetricKeys.LOSS + '/head2': expected_loss_head2,
# Average loss over examples.
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
index 7005a647db..0848c5f62f 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
@@ -34,10 +34,12 @@ from tensorflow.python.estimator import util
from tensorflow.python.estimator.export import export_output as export_output_lib
from tensorflow.python.framework import device as framework_device
from tensorflow.python.framework import ops as ops_lib
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients as gradients_lib
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
@@ -183,10 +185,17 @@ def _split_batch(features, labels, number_of_shards, device):
"""Split input features and labes into batches."""
def split_dictionary(dictionary):
+ """Split a dictionary into shards."""
shards = [{} for _ in range(number_of_shards)]
for name, tensor in six.iteritems(dictionary):
- for i, shard in enumerate(array_ops.split(tensor, number_of_shards)):
- shards[i][name] = shard
+ if isinstance(tensor, sparse_tensor.SparseTensor):
+ for i, shard in enumerate(
+ sparse_ops.sparse_split(
+ sp_input=tensor, num_split=number_of_shards, axis=0)):
+ shards[i][name] = shard
+ else:
+ for i, shard in enumerate(array_ops.split(tensor, number_of_shards)):
+ shards[i][name] = shard
return shards
with ops_lib.name_scope('split_inputs'):
@@ -313,7 +322,17 @@ def _call_optimizer_fn(optimizer_fn, params):
def _compute_sum_on_device(values, device, name=None):
with ops_lib.device(device):
- return math_ops.add_n(values, name=name)
+ if isinstance(values[0], ops_lib.IndexedSlices):
+ if name:
+ raise ValueError('The name {} is not expected to be given to '
+ 'IndexedSlices {}'.format(name, values))
+
+ values_concat = array_ops.concat([v.values for v in values], axis=0)
+ indices_concat = array_ops.concat([v.indices for v in values], axis=0)
+ return ops_lib.IndexedSlices(values_concat, indices_concat,
+ values[0].dense_shape)
+ else:
+ return math_ops.add_n(values, name=name)
def _train_spec(tower_specs,
@@ -338,25 +357,17 @@ def _eval_spec(tower_specs, aggregation_device, aggregated_loss_name='loss'):
[spec.loss for spec in tower_specs], aggregation_device,
aggregated_loss_name)
- eval_metric_ops_lists = {}
+ update_ops = []
for tower_spec in tower_specs:
- metrics = tower_spec.eval_metric_ops or {}
- for name, (_, update_op) in six.iteritems(metrics):
- update_ops = eval_metric_ops_lists.setdefault(name, ([]))
+ for name, (_, update_op) in six.iteritems(tower_spec.eval_metric_ops):
update_ops.append(update_op)
+ with ops_lib.control_dependencies(update_ops):
+ reduced_update_op = _reduce_metric_variables(len(tower_specs))
+
eval_metric_ops = {}
for name, (metric_tensor, _) in six.iteritems(tower_specs[0].eval_metric_ops):
- with ops_lib.control_dependencies(eval_metric_ops_lists[name]):
- # This operation reduces local variables across all metrics, yet is
- # called for every metric. This is redundant and it's done because
- # it is hard to know what local variables correspond to what metric.
- # Estimator is going to execute all `reduced_update_op`s as part of
- # a group inside a single `Session.run()` call, which will avoid duplicate
- # computation.
- reduced_update_op = _reduce_metric_variables(len(tower_specs))
eval_metric_ops[name] = (metric_tensor, reduced_update_op)
-
estimator_spec['eval_metric_ops'] = eval_metric_ops
return model_fn_lib.EstimatorSpec(**estimator_spec)
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
index 10b47fba5a..7fb1065ac0 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
@@ -65,20 +65,35 @@ class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase):
data = np.linspace(
0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
x_data = data.reshape(batch_size, input_dimension)
+ categorical_data = np.random.random_integers(
+ 0, len(x_data), size=len(x_data))
y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))
train_input_fn = numpy_io.numpy_input_fn(
- x={'x': x_data},
+ x={'x': x_data,
+ 'categories': categorical_data},
y=y_data,
batch_size=batch_size,
num_epochs=None,
shuffle=True)
eval_input_fn = numpy_io.numpy_input_fn(
- x={'x': x_data}, y=y_data, batch_size=batch_size, shuffle=False)
+ x={'x': x_data,
+ 'categories': categorical_data},
+ y=y_data,
+ batch_size=batch_size,
+ shuffle=False)
predict_input_fn = numpy_io.numpy_input_fn(
- x={'x': x_data}, batch_size=batch_size, shuffle=False)
+ x={'x': x_data,
+ 'categories': categorical_data},
+ batch_size=batch_size,
+ shuffle=False)
feature_columns = [
- feature_column.numeric_column('x', shape=(input_dimension,))
+ feature_column.numeric_column('x', shape=(input_dimension,)),
+ feature_column.embedding_column(
+ feature_column.categorical_column_with_vocabulary_list(
+ 'categories',
+ vocabulary_list=np.linspace(
+ 0., len(x_data), len(x_data), dtype=np.int64)), 1)
]
estimator = dnn.DNNClassifier(
@@ -90,14 +105,11 @@ class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase):
def optimizer_fn():
return optimizers.get_optimizer_instance('Adagrad', learning_rate=0.05)
- # TODO(isaprykin): Switch Estimator to use allow_soft_placement=True
- # during export_savedmodel and then switch this test to replicate over
- # GPUs instead of CPUs.
estimator = estimator_lib.Estimator(
model_fn=replicate_model_fn.replicate_model_fn(
estimator.model_fn,
optimizer_fn,
- devices=['/cpu:0', '/cpu:0', '/cpu:0']),
+ devices=['/gpu:0', '/gpu:1', '/gpu:2']),
model_dir=estimator.model_dir,
config=estimator.config,
params=estimator.params)
@@ -230,6 +242,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
accuracy = session.run(accuracy)
auc = session.run(auc)
+ # loss[i] = features[i] * 10 - labels[i].
# Accuracy is 0.0 (no match) in the first tower.
# Accuracy is 1.0 (match) in the second tower, since the feature
# times weight "c" happened to be equal to the label.
@@ -531,8 +544,7 @@ class EvalSpecTest(test_util.TensorFlowTestCase):
self.assertEqual('/device:CPU:0', auc.device)
session.run([a, b])
- accuracy = session.run(accuracy)
- auc = session.run(auc)
+ accuracy, auc = session.run([accuracy, auc])
self.assertNear((12 - 2) / 12, accuracy, 0.01)
self.assertEqual(0, auc)
@@ -861,7 +873,7 @@ class LocalDeviceSetterTest(test_util.TensorFlowTestCase):
class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase):
- def test_example(self):
+ def test_vectors(self):
with self.test_session() as session:
total = replicate_model_fn._compute_sum_on_device(
[1.0, 2.0, 3.0, 4.0], device='/device:GPU:0', name='test_sum')
@@ -870,6 +882,68 @@ class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase):
self.assertEqual('test_sum', total.op.name)
self.assertEqual(10.0, session.run(total))
+ def test_tensors(self):
+ with self.test_session() as session:
+ total = replicate_model_fn._compute_sum_on_device(
+ [[1.0, 2.0], [3.0, 4.0]], device='/device:GPU:0', name='test_sum')
+
+ self.assertEqual('/device:GPU:0', total.device)
+ self.assertEqual('test_sum', total.op.name)
+ self.assertAllEqual([4.0, 6.0], session.run(total))
+
+ def test_indexedslices(self):
+ with self.test_session() as session:
+ a = ops_lib.IndexedSlices(
+ constant_op.constant([1.0, 2.0]), [0, 1],
+ dense_shape=constant_op.constant([2]))
+ b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1])
+
+ total = replicate_model_fn._compute_sum_on_device(
+ [a, b], device='/device:GPU:0')
+
+ self.assertEqual('/device:GPU:0', total.device)
+ self.assertAllEqual([4.0, 6.0],
+ session.run(ops_lib.convert_to_tensor(total)))
+
+ def test_indexedslices_higher_dimensions(self):
+ with self.test_session() as session:
+ a = ops_lib.IndexedSlices(
+ constant_op.constant([[1.0, 5.0], [2.0, 6.0]]), [0, 1],
+ dense_shape=constant_op.constant([2, 4]))
+ b = ops_lib.IndexedSlices(
+ constant_op.constant([[3.0, 7.0], [4.0, 8.0]]), [0, 1])
+
+ total = replicate_model_fn._compute_sum_on_device(
+ [a, b], device='/device:GPU:0')
+
+ self.assertEqual('/device:GPU:0', total.device)
+ self.assertAllEqual([[4.0, 12.0], [6.0, 14.0]],
+ session.run(ops_lib.convert_to_tensor(total)))
+
+ def test_indexedslices_some_dont_overlap(self):
+ with self.test_session() as session:
+ a = ops_lib.IndexedSlices(
+ constant_op.constant([1.0, 2.0]), [0, 3],
+ dense_shape=constant_op.constant([4]))
+ b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1])
+
+ total = replicate_model_fn._compute_sum_on_device(
+ [a, b], device='/device:GPU:0')
+
+ self.assertEqual('/device:GPU:0', total.device)
+ self.assertAllEqual([4.0, 4.0, 0.0, 2.0],
+ session.run(ops_lib.convert_to_tensor(total)))
+
+ def test_no_name_for_indexslices(self):
+ a = ops_lib.IndexedSlices(
+ constant_op.constant([1.0, 2.0]), [0, 1],
+ dense_shape=constant_op.constant([2]))
+ b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1])
+
+ with self.assertRaisesRegexp(ValueError, ''):
+ _ = replicate_model_fn._compute_sum_on_device(
+ [a, b], device='/device:GPU:0', name='cant_name_indexslices')
+
class ConcatTensorDictsTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
index e89993991a..0824ecf616 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
@@ -76,7 +76,7 @@ class GANEstimator(estimator.Estimator):
return logits
# Create GAN estimator.
- gan_estimator = estimator.GANEstimator(
+ gan_estimator = tfgan.estimator.GANEstimator(
model_dir,
generator_fn=generator_fn,
discriminator_fn=discriminator_fn,
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
index 5d86373a23..5b7747b0a1 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
@@ -139,6 +139,7 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:random_ops",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
index 87339cb059..39ce3e9337 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
@@ -24,6 +24,7 @@ from tensorflow.contrib.kfac.python.ops import loss_functions
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
@@ -96,6 +97,22 @@ class CategoricalLogitsNegativeLogProbLossTest(test.TestCase):
# difficult to say if the output is correct or not...
neg_log_prob = sess.run(neg_log_prob)
+ def testMultiMinibatchRegistration(self):
+ """Ensure this loss function supports registering multiple minibatches."""
+ with ops.Graph().as_default():
+ tower_logits = []
+ loss = None
+ num_towers = 5
+ for _ in range(num_towers):
+ logits = random_ops.random_uniform(shape=[2, 3])
+ tower_logits.append(logits)
+ if loss is None:
+ loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
+ else:
+ loss.register_additional_minibatch(logits)
+ self.assertListEqual(loss.input_minibatches, tower_logits)
+ self.assertEqual(loss.num_registered_minibatches, num_towers)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index 4eabb59b3e..7300a7998c 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -448,10 +448,10 @@ class LayerCollection(object):
tf.get_variable_scope().reuse.
Raises:
- ValueError: If reuse=True and name != None.
- ValueError: If reuse=True and seed != None.
- KeyError: If reuse=True and no existing LossFunction with 'name' found.
- KeyError: If reuse=False and existing LossFunction with 'name' found.
+ ValueError: If reuse == True and name == None.
+ ValueError: If reuse == True and seed != None.
+ KeyError: If reuse == True and no existing LossFunction with 'name' found.
+ KeyError: If reuse == False and existing LossFunction with 'name' found.
"""
name = name or self._graph.unique_name(
"register_categorical_predictive_distribution")
diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py
index 3cfde7f9ab..e2e5bc3ffe 100644
--- a/tensorflow/contrib/kfac/python/ops/loss_functions.py
+++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py
@@ -56,6 +56,30 @@ class LossFunction(object):
"""The inputs to the loss function (excluding the targets)."""
pass
+ @property
+ def input_minibatches(self):
+ """A `list` of inputs to the loss function, separated by minibatch.
+
+ Typically there will be one minibatch per tower in a multi-tower setup.
+ Returns a list consisting of `self.inputs` by default; `LossFunction`s
+ supporting registering multiple minibatches should override this method.
+
+ Returns:
+ A `list` of `Tensor`s representing
+ """
+ return [self.inputs]
+
+ @property
+ def num_registered_minibatches(self):
+ """Number of minibatches registered for this LossFunction.
+
+ Typically equal to the number of towers in a multi-tower setup.
+
+ Returns:
+ An `int` representing the number of registered minibatches.
+ """
+ return len(self.input_minibatches)
+
def evaluate(self):
"""Evaluate the loss function on the targets."""
if self.targets is not None:
@@ -75,7 +99,6 @@ class LossFunction(object):
Returns:
log probability of each target, summed across all targets.
"""
-
pass
@abc.abstractmethod
@@ -415,8 +438,8 @@ class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss,
array_ops.ones(array_ops.shape(self._mean)[:1], dtype=self._mean.dtype),
axis=-1)
output_slice = self._var**-0.5 * ones_slice
- return insert_slice_in_zeros(output_slice, 1,
- int(self._mean.shape[1]), index[0])
+ return insert_slice_in_zeros(output_slice, 1, int(self._mean.shape[1]),
+ index[0])
@property
def fisher_factor_inner_shape(self):
@@ -474,24 +497,23 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
@property
def _fisher_mean(self):
- return 1./self._variance
+ return 1. / self._variance
@property
def _fisher_mean_factor(self):
- return 1./self._scale
+ return 1. / self._scale
@property
def _fisher_var(self):
- return 1./(2*math_ops.square(self._variance))
+ return 1. / (2 * math_ops.square(self._variance))
@property
def _fisher_var_factor(self):
- return 1./(math_ops.sqrt(2.)*self._variance)
+ return 1. / (math_ops.sqrt(2.) * self._variance)
def multiply_fisher(self, vecs):
mean_vec, var_vec = vecs
- return (self._fisher_mean * mean_vec,
- self._fisher_var * var_vec)
+ return (self._fisher_mean * mean_vec, self._fisher_var * var_vec)
def multiply_fisher_factor(self, vecs):
mean_vec, var_vec = self._split(vecs)
@@ -511,8 +533,8 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
# Index corresponds to mean parameter.
mean_slice = self._fisher_mean_factor[:, index]
mean_slice = array_ops.expand_dims(mean_slice, axis=-1)
- mean_output = insert_slice_in_zeros(mean_slice, 1,
- int(self._mean.shape[1]), index)
+ mean_output = insert_slice_in_zeros(mean_slice, 1, int(
+ self._mean.shape[1]), index)
var_output = array_ops.zeros_like(mean_output)
else:
index -= int(self._mean.shape[-1])
@@ -527,13 +549,17 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
@property
def fisher_factor_inner_shape(self):
- return array_ops.concat([array_ops.shape(self._mean)[:-1],
- 2*array_ops.shape(self._mean)[-1:]], axis=0)
+ return array_ops.concat(
+ [
+ array_ops.shape(self._mean)[:-1],
+ 2 * array_ops.shape(self._mean)[-1:]
+ ],
+ axis=0)
@property
def fisher_factor_inner_static_shape(self):
shape = self._mean.shape.as_list()
- return tensor_shape.TensorShape(shape[-1:] + [2*shape[-1]])
+ return tensor_shape.TensorShape(shape[-1:] + [2 * shape[-1]])
def multiply_hessian(self, vector):
raise NotImplementedError()
@@ -606,6 +632,10 @@ class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss,
return array_ops.concat(self._logits_components, axis=0)
@property
+ def input_minibatches(self):
+ return self._logits_components
+
+ @property
def targets(self):
if all(target is None for target in self._targets_components):
return None
@@ -710,8 +740,8 @@ class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss,
assert len(index) == 1, "Length of index was {}".format(len(index))
probs_slice = array_ops.expand_dims(self._probs[:, index[0]], -1)
output_slice = math_ops.sqrt(probs_slice * (1 - probs_slice))
- return insert_slice_in_zeros(output_slice, 1,
- int(self._logits.shape[1]), index[0])
+ return insert_slice_in_zeros(output_slice, 1, int(self._logits.shape[1]),
+ index[0])
@property
def fisher_factor_inner_shape(self):
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 2917a30a17..94920db574 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -22,6 +22,8 @@ py_library(
exclude = ["python/learn/**/*_test.py"],
),
srcs_version = "PY2AND3",
+ # This library should not depend on sklearn, even though some of the code
+ # refers to it. (The code handles the presence of sklearn conditionally.)
deps = [
"//tensorflow/contrib/factorization:factorization_py",
"//tensorflow/contrib/framework:framework_py",
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
new file mode 100644
index 0000000000..c58f77cb11
--- /dev/null
+++ b/tensorflow/contrib/lite/BUILD
@@ -0,0 +1,280 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops")
+
+exports_files(glob([
+ "testdata/*.bin",
+ "models/testdata/*",
+]))
+
+config_setting(
+ name = "mips",
+ values = {
+ "cpu": "mips",
+ },
+)
+
+config_setting(
+ name = "mips64",
+ values = {
+ "cpu": "mips64",
+ },
+)
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+)
+
+cc_library(
+ name = "schema_fbs_version",
+ hdrs = ["version.h"],
+)
+
+# Main library. No ops are included here.
+# TODO(aselle): Resolve problems preventing C99 usage.
+cc_library(
+ name = "context",
+ srcs = ["context.c"],
+ hdrs = ["context.h"],
+)
+
+cc_library(
+ name = "builtin_op_data",
+ hdrs = [
+ "builtin_op_data.h",
+ ],
+)
+
+cc_library(
+ name = "string",
+ hdrs = [
+ "string.h",
+ ],
+ deps = [
+ "//tensorflow/core:lib_platform",
+ ],
+)
+
+# TODO(ahentz): investigate dependency on gemm_support requiring usage of tf_copts.
+cc_library(
+ name = "framework",
+ srcs = [
+ "allocation.cc",
+ "error_reporter.cc",
+ "interpreter.cc",
+ "model.cc",
+ "nnapi_delegate.cc",
+ "optional_debug_tools.cc",
+ "simple_memory_arena.cc",
+ ],
+ hdrs = [
+ "allocation.h",
+ "context.h",
+ "error_reporter.h",
+ "interpreter.h",
+ "model.h",
+ "nnapi_delegate.h",
+ "optional_debug_tools.h",
+ "simple_memory_arena.h",
+ ],
+ copts = tflite_copts(),
+ deps = [
+ ":builtin_op_data",
+ ":context",
+ ":schema_fbs_version",
+ "//tensorflow/contrib/lite/kernels:gemm_support",
+ "//tensorflow/contrib/lite/nnapi:nnapi_lib",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ "//tensorflow/core:lib_platform",
+ ],
+)
+
+cc_library(
+ name = "string_util",
+ srcs = ["string_util.cc"],
+ hdrs = ["string_util.h"],
+ deps = [
+ ":framework",
+ ":string",
+ ],
+)
+
+cc_test(
+ name = "string_util_test",
+ size = "small",
+ srcs = ["string_util_test.cc"],
+ deps = [
+ ":framework",
+ ":string_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+# Test main interpreter
+cc_test(
+ name = "interpreter_test",
+ size = "small",
+ srcs = ["interpreter_test.cc"],
+ deps = [
+ ":framework",
+ ":string_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+# Test arena allocator
+cc_test(
+ name = "simple_memory_arena_test",
+ size = "small",
+ srcs = ["simple_memory_arena_test.cc"],
+ deps = [
+ ":framework",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+# Test model framework.
+cc_test(
+ name = "model_test",
+ size = "small",
+ srcs = ["model_test.cc"],
+ data = [
+ "testdata/0_subgraphs.bin",
+ "testdata/2_subgraphs.bin",
+ "testdata/empty_model.bin",
+ "testdata/test_model.bin",
+ "testdata/test_model_broken.bin",
+ ],
+ deps = [
+ ":framework",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+# Test the C extension API code.
+cc_test(
+ name = "context_test",
+ size = "small",
+ srcs = ["context_test.cc"],
+ deps = [
+ ":framework",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+# Test the serialization of a model with optional tensors.
+
+# Model tests
+
+cc_library(
+ name = "models_test_utils",
+ testonly = 1,
+ hdrs = ["models/test_utils.h"],
+ deps = select({
+ "//tensorflow:android": [],
+ "//conditions:default": [
+ #"//file/base:path",
+ "//tensorflow/core:test",
+ ],
+ }),
+)
+
+cc_test(
+ name = "speech_hotword_model_test",
+ size = "small",
+ srcs = ["models/speech_hotword_model_test.cc"],
+ data = [
+ "models/testdata/speech_hotword_model_in.csv",
+ "models/testdata/speech_hotword_model_out_rank1.csv",
+ "models/testdata/speech_hotword_model_out_rank2.csv",
+ "models/testdata/speech_hotword_model_rank1.tflite",
+ "models/testdata/speech_hotword_model_rank2.tflite",
+ ],
+ deps = [
+ ":framework",
+ ":models_test_utils",
+ #"//file/base:path",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+gen_selected_ops(
+ name = "speech_speakerid_ops",
+ model = "models/testdata/speech_speakerid_model.tflite",
+)
+
+cc_test(
+ name = "speech_speakerid_model_test",
+ size = "small",
+ srcs = [
+ "models/speech_speakerid_model_test.cc",
+ ":speech_speakerid_ops",
+ ],
+ data = [
+ "models/testdata/speech_speakerid_model.tflite",
+ "models/testdata/speech_speakerid_model_in.csv",
+ "models/testdata/speech_speakerid_model_out.csv",
+ ],
+ deps = [
+ ":framework",
+ ":models_test_utils",
+ #"//file/base:path",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/tools:mutable_op_resolver",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "speech_terse_am_model_test",
+ size = "small",
+ srcs = ["models/speech_terse_am_model_test.cc"],
+ data = [
+ "models/testdata/speech_terse_am_model.tflite",
+ "models/testdata/speech_terse_am_model_in.csv",
+ "models/testdata/speech_terse_am_model_out.csv",
+ ],
+ deps = [
+ ":framework",
+ ":models_test_utils",
+ #"//file/base:path",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "speech_tts_model_test",
+ size = "small",
+ srcs = ["models/speech_tts_model_test.cc"],
+ data = [
+ "models/testdata/speech_tts_model.tflite",
+ "models/testdata/speech_tts_model_in.csv",
+ "models/testdata/speech_tts_model_out.csv",
+ ],
+ deps = [
+ ":framework",
+ ":models_test_utils",
+ #"//file/base:path",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/allocation.cc b/tensorflow/contrib/lite/allocation.cc
new file mode 100644
index 0000000000..4b322e027d
--- /dev/null
+++ b/tensorflow/contrib/lite/allocation.cc
@@ -0,0 +1,122 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <fcntl.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+#include <cassert>
+#include <cstdarg>
+#include <cstdint>
+#include <cstring>
+
+#include "tensorflow/contrib/lite/allocation.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/nnapi_delegate.h"
+
+namespace tflite {
+
+MMAPAllocation::MMAPAllocation(const char* filename,
+ ErrorReporter* error_reporter)
+ : Allocation(error_reporter), mmapped_buffer_(MAP_FAILED) {
+ mmap_fd_ = open(filename, O_RDONLY);
+ if (mmap_fd_ == -1) {
+ error_reporter_->Report("Could not open '%s'.", filename);
+ return;
+ }
+ struct stat sb;
+ fstat(mmap_fd_, &sb);
+ buffer_size_bytes_ = sb.st_size;
+ mmapped_buffer_ =
+ mmap(nullptr, buffer_size_bytes_, PROT_READ, MAP_SHARED, mmap_fd_, 0);
+ if (mmapped_buffer_ == MAP_FAILED) {
+ error_reporter_->Report("Mmap of '%s' failed.", filename);
+ return;
+ }
+}
+
+MMAPAllocation::~MMAPAllocation() {
+ if (valid()) {
+ munmap(const_cast<void*>(mmapped_buffer_), buffer_size_bytes_);
+ }
+ if (mmap_fd_ != -1) close(mmap_fd_);
+}
+
+const void* MMAPAllocation::base() const { return mmapped_buffer_; }
+
+size_t MMAPAllocation::bytes() const { return buffer_size_bytes_; }
+
+bool MMAPAllocation::valid() const { return mmapped_buffer_ != MAP_FAILED; }
+
+FileCopyAllocation::FileCopyAllocation(const char* filename,
+ ErrorReporter* error_reporter)
+ : Allocation(error_reporter) {
+ // Obtain the file size, using an alternative method that is does not
+ // require fstat for more compatibility.
+ std::unique_ptr<FILE, decltype(&fclose)> file(fopen(filename, "rb"), fclose);
+ if (!file) {
+ error_reporter_->Report("Could not open '%s'.", filename);
+ return;
+ }
+ // TODO(ahentz): Why did you think using fseek here was better for finding
+ // the size?
+ struct stat sb;
+ if (fstat(fileno(file.get()), &sb) != 0) {
+ error_reporter_->Report("Failed to get file size of '%s'.", filename);
+ return;
+ }
+ buffer_size_bytes_ = sb.st_size;
+ std::unique_ptr<char[]> buffer(new char[buffer_size_bytes_]);
+ if (!buffer) {
+ error_reporter_->Report("Malloc of buffer to hold copy of '%s' failed.",
+ filename);
+ return;
+ }
+ size_t bytes_read =
+ fread(buffer.get(), sizeof(char), buffer_size_bytes_, file.get());
+ if (bytes_read != buffer_size_bytes_) {
+ error_reporter_->Report("Read of '%s' failed (too few bytes read).",
+ filename);
+ return;
+ }
+ copied_buffer_ = std::move(buffer);
+}
+
+FileCopyAllocation::~FileCopyAllocation() {}
+
+const void* FileCopyAllocation::base() const { return copied_buffer_.get(); }
+
+size_t FileCopyAllocation::bytes() const { return buffer_size_bytes_; }
+
+bool FileCopyAllocation::valid() const { return copied_buffer_ != nullptr; }
+
+MemoryAllocation::MemoryAllocation(const void* ptr, size_t num_bytes,
+ ErrorReporter* error_reporter)
+ : Allocation(error_reporter) {
+ buffer_ = ptr;
+ buffer_size_bytes_ = num_bytes;
+}
+
+MemoryAllocation::~MemoryAllocation() {}
+
+const void* MemoryAllocation::base() const { return buffer_; }
+
+size_t MemoryAllocation::bytes() const { return buffer_size_bytes_; }
+
+bool MemoryAllocation::valid() const { return buffer_ != nullptr; }
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/allocation.h b/tensorflow/contrib/lite/allocation.h
new file mode 100644
index 0000000000..ee8a7ccd0b
--- /dev/null
+++ b/tensorflow/contrib/lite/allocation.h
@@ -0,0 +1,94 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Main abstraction controlling the tflite interpreter.
+// See context.h for the API for defining operations (TfLiteRegistration).
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_
+
+#include <cstdio>
+#include <cstdlib>
+#include <vector>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/simple_memory_arena.h"
+
+namespace tflite {
+
+// A memory allocation handle. This could be a mmap or shared memory.
+class Allocation {
+ public:
+ Allocation(ErrorReporter* error_reporter) : error_reporter_(error_reporter) {}
+ virtual ~Allocation() {}
+
+ // Base pointer of this allocation
+ virtual const void* base() const = 0;
+ // Size in bytes of the allocation
+ virtual size_t bytes() const = 0;
+ // Whether the allocation is valid
+ virtual bool valid() const = 0;
+
+ protected:
+ ErrorReporter* error_reporter_;
+};
+
+class MMAPAllocation : public Allocation {
+ public:
+ MMAPAllocation(const char* filename, ErrorReporter* error_reporter);
+ virtual ~MMAPAllocation();
+ const void* base() const override;
+ size_t bytes() const override;
+ bool valid() const override;
+
+ protected:
+ // Data required for mmap.
+ int mmap_fd_ = -1; // mmap file descriptor
+ const void* mmapped_buffer_;
+ size_t buffer_size_bytes_ = 0;
+};
+
+class FileCopyAllocation : public Allocation {
+ public:
+ FileCopyAllocation(const char* filename, ErrorReporter* error_reporter);
+ virtual ~FileCopyAllocation();
+ const void* base() const override;
+ size_t bytes() const override;
+ bool valid() const override;
+
+ private:
+ // Data required for mmap.
+ std::unique_ptr<const char[]> copied_buffer_;
+ size_t buffer_size_bytes_ = 0;
+};
+
+class MemoryAllocation : public Allocation {
+ public:
+ // Allocates memory with the pointer and the number of bytes of the memory.
+ // The pointer has to remain alive and unchanged until the destructor is
+ // called.
+ MemoryAllocation(const void* ptr, size_t num_bytes,
+ ErrorReporter* error_reporter);
+ virtual ~MemoryAllocation();
+ const void* base() const override;
+ size_t bytes() const override;
+ bool valid() const override;
+
+ private:
+ const void* buffer_;
+ size_t buffer_size_bytes_ = 0;
+};
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
new file mode 100644
index 0000000000..e3c9cdd99b
--- /dev/null
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -0,0 +1,233 @@
+"""Generate Flatbuffer binary from json."""
+
+def tflite_copts():
+ """Defines compile time flags."""
+ copts = [
+ "-DFARMHASH_NO_CXX_STRING",
+ ] + select({
+ "//tensorflow:android_arm64": [
+ "-std=c++11",
+ "-O3",
+ ],
+ "//tensorflow:android_arm": [
+ "-mfpu=neon",
+ "-mfloat-abi=softfp",
+ "-std=c++11",
+ "-O3",
+ ],
+ "//tensorflow:android_x86": [
+ "-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK",
+ ],
+ "//tensorflow:ios_x86_64": [
+ "-msse4.1",
+ ],
+ "//conditions:default": [],
+ }) + select({
+ "//tensorflow:with_default_optimizations": [],
+ "//conditions:default": ["-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK"],
+ })
+
+ return copts
+
+LINKER_SCRIPT = "//tensorflow/contrib/lite/java/src/main/native:version_script.lds"
+
+def tflite_linkopts_unstripped():
+ """Defines linker flags to reduce size of TFLite binary.
+
+ These are useful when trying to investigate the relative size of the
+ symbols in TFLite.
+
+ Returns:
+ a select object with proper linkopts
+ """
+ return select({
+ "//tensorflow:android": [
+ "-Wl,--no-export-dynamic", # Only inc syms referenced by dynamic obj.
+ "-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export.
+ "-Wl,--gc-sections", # Eliminate unused code and data.
+ "-Wl,--as-needed", # Don't link unused libs.
+ ],
+ "//tensorflow/contrib/lite:mips": [],
+ "//tensorflow/contrib/lite:mips64": [],
+ "//conditions:default": [
+ "-Wl,--icf=all", # Identical code folding.
+ ],
+ })
+
+def tflite_jni_linkopts_unstripped():
+ """Defines linker flags to reduce size of TFLite binary with JNI.
+
+ These are useful when trying to investigate the relative size of the
+ symbols in TFLite.
+
+ Returns:
+ a select object with proper linkopts
+ """
+ return select({
+ "//tensorflow:android": [
+ "-Wl,--gc-sections", # Eliminate unused code and data.
+ "-Wl,--as-needed", # Don't link unused libs.
+ ],
+ "//tensorflow/contrib/lite:mips": [],
+ "//tensorflow/contrib/lite:mips64": [],
+ "//conditions:default": [
+ "-Wl,--icf=all", # Identical code folding.
+ ],
+ })
+
+def tflite_linkopts():
+ """Defines linker flags to reduce size of TFLite binary."""
+ return tflite_linkopts_unstripped() + select({
+ "//tensorflow:android": [
+ "-s", # Omit symbol table.
+ ],
+ "//conditions:default": [],
+ })
+
+def tflite_jni_linkopts():
+ """Defines linker flags to reduce size of TFLite binary with JNI."""
+ return tflite_jni_linkopts_unstripped() + select({
+ "//tensorflow:android": [
+ "-s", # Omit symbol table.
+ ],
+ "//conditions:default": [],
+ })
+
+
+def tflite_jni_binary(name,
+ copts=tflite_copts(),
+ linkopts=tflite_jni_linkopts(),
+ linkscript=LINKER_SCRIPT,
+ linkshared=1,
+ linkstatic=1,
+ deps=[]):
+ """Builds a jni binary for TFLite."""
+ linkopts = linkopts + [
+ "-Wl,--version-script", # Export only jni functions & classes.
+ linkscript,
+ ]
+ native.cc_binary(
+ name=name,
+ copts=copts,
+ linkshared=linkshared,
+ linkstatic=linkstatic,
+ deps= deps + [linkscript],
+ linkopts=linkopts)
+
+def tf_to_tflite(name, src, options, out):
+ """Convert a frozen tensorflow graphdef to TF Lite's flatbuffer.
+
+ Args:
+ name: Name of rule.
+ src: name of the input graphdef file.
+ options: options passed to TOCO.
+ out: name of the output flatbuffer file.
+ """
+
+ toco = "//tensorflow/contrib/lite/toco:toco"
+ native.genrule(
+ name = name,
+ srcs=[src, options],
+ outs=[out],
+ cmd = ("$(location %s) " +
+ " --input_file=$(location %s) " +
+ " --output_file=$(location %s) " +
+ " --input_format=TENSORFLOW_GRAPHDEF" +
+ " --output_format=TFLITE" +
+ " `cat $(location %s)`")
+ % (toco, src, out, options),
+ tools= [toco],
+ )
+
+def tflite_to_json(name, src, out):
+ """Convert a TF Lite flatbuffer to JSON.
+
+ Args:
+ name: Name of rule.
+ src: name of the input flatbuffer file.
+ out: name of the output JSON file.
+ """
+
+ flatc = "@flatbuffers//:flatc"
+ schema = "//tensorflow/contrib/lite/schema:schema.fbs"
+ native.genrule(
+ name = name,
+ srcs = [schema, src],
+ outs = [out],
+ cmd = ("TMP=`mktemp`; cp $(location %s) $${TMP}.bin &&" +
+ "$(location %s) --raw-binary --strict-json -t" +
+ " -o /tmp $(location %s) -- $${TMP}.bin &&" +
+ "cp $${TMP}.json $(location %s)")
+ % (src, flatc, schema, out),
+ tools = [flatc],
+ )
+
+def json_to_tflite(name, src, out):
+ """Convert a JSON file to TF Lite's flatbuffer.
+
+ Args:
+ name: Name of rule.
+ src: name of the input JSON file.
+ out: name of the output flatbuffer file.
+ """
+
+ flatc = "@flatbuffers//:flatc"
+ schema = "//tensorflow/contrib/lite/schema:schema_fbs"
+ native.genrule(
+ name = name,
+ srcs = [schema, src],
+ outs = [out],
+ cmd = ("TMP=`mktemp`; cp $(location %s) $${TMP}.json &&" +
+ "$(location %s) --raw-binary --unknown-json --allow-non-utf8 -b" +
+ " -o /tmp $(location %s) $${TMP}.json &&" +
+ "cp $${TMP}.bin $(location %s)")
+ % (src, flatc, schema, out),
+ tools = [flatc],
+ )
+
+def gen_zipped_test_files(name, files):
+ """Generate a zip file of tests by using :generate_examples.
+
+ Args:
+ name: Name of output. We will produce "`name`_files" as a target.
+ files: A list of zip file basenames.
+ """
+ toco = "//tensorflow/contrib/lite/toco:toco"
+ out_files = []
+ for f in files:
+ out_file = name + "/" + f
+ out_files.append(out_file)
+ native.genrule(
+ name = name + "_" + f + ".files",
+ cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco
+ + " --zip_to_output " + f +
+ " $(@D) zipped"),
+ outs = [out_file],
+ tools = [
+ ":generate_examples",
+ toco,
+ ],
+ )
+
+ native.filegroup(
+ name = name,
+ srcs = out_files,
+ )
+
+def gen_selected_ops(name, model):
+ """Generate the library that includes only used ops.
+
+ Args:
+ name: Name of the generated library.
+ model: TFLite model to interpret.
+ """
+ out = name + "_registration.cc"
+ tool = "//tensorflow/contrib/lite/tools:generate_op_registrations"
+ native.genrule(
+ name = name,
+ srcs = [model],
+ outs = [out],
+ cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s)")
+ % (tool, model, out),
+ tools = [tool],
+ )
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
new file mode 100644
index 0000000000..93072bf90b
--- /dev/null
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -0,0 +1,164 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
+
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// TODO(aselle): Consider using "if this then that" for testing.
+
+// Possible padding types (for convolutions)
+typedef enum {
+ kTfLitePaddingUnknown = 0,
+ kTfLitePaddingSame,
+ kTfLitePaddingValid,
+} TfLitePadding;
+
+typedef struct {
+ int width;
+ int height;
+} TfLitePaddingValues;
+
+// Possible fused activation functions.
+// TODO(aselle): rename to TfLiteActivation
+typedef enum {
+ kTfLiteActNone = 0,
+ kTfLiteActRelu,
+ kTfLiteActRelu1,
+ kTfLiteActRelu6,
+ kTfLiteActTanh,
+ kTfLiteActSignBit,
+ kTfLiteActSigmoid,
+} TfLiteFusedActivation;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+ TfLiteFusedActivation activation;
+} TfLiteConvParams;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+ int filter_width;
+ int filter_height;
+ TfLiteFusedActivation activation;
+ struct {
+ TfLitePaddingValues padding;
+ } computed;
+} TfLitePoolParams;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+ int depth_multiplier;
+ TfLiteFusedActivation activation;
+} TfLiteDepthwiseConvParams;
+
+typedef struct {
+ int rank;
+ TfLiteFusedActivation activation;
+} TfLiteSVDFParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteRNNParams;
+
+typedef struct { TfLiteFusedActivation activation; } TfLiteFullyConnectedParams;
+
+typedef enum {
+ kTfLiteLshProjectionUnknown = 0,
+ kTfLiteLshProjectionSparse = 1,
+ kTfLiteLshProjectionDense = 2,
+} TfLiteLSHProjectionType;
+
+typedef struct { TfLiteLSHProjectionType type; } TfLiteLSHProjectionParams;
+
+typedef struct { float beta; } TfLiteSoftmaxParams;
+
+typedef struct {
+ int axis;
+ TfLiteFusedActivation activation;
+} TfLiteConcatenationParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteAddParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteMulParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteL2NormParams;
+
+typedef struct {
+ int radius;
+ float bias;
+ float alpha;
+ float beta;
+} TfLiteLocalResponseNormParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+ float cell_clip;
+ float proj_clip;
+} TfLiteLSTMParams;
+
+typedef struct {
+ int new_height;
+ int new_width;
+} TfLiteResizeBilinearParams;
+
+typedef struct {
+ // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
+ // For now we will fix the maximum possible number of dimensions.
+ int shape[8];
+ int num_dimensions;
+} TfLiteReshapeParams;
+
+typedef struct {
+ int ngram_size;
+ int max_skip_size;
+ bool include_all_ngrams;
+} TfLiteSkipGramParams;
+
+typedef struct {
+ int block_size;
+} TfLiteSpaceToDepthParams;
+
+typedef enum {
+ kTfLiteCombinerTypeSum = 0,
+ kTfLiteCombinerTypeMean = 1,
+ kTfLiteCombinerTypeSqrtn = 2,
+} TfLiteCombinerType;
+
+typedef struct {
+ TfLiteCombinerType combiner;
+} TfLiteEmbeddingLookupSparseParams;
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
diff --git a/tensorflow/contrib/lite/context.c b/tensorflow/contrib/lite/context.c
new file mode 100644
index 0000000000..c09e838c5c
--- /dev/null
+++ b/tensorflow/contrib/lite/context.c
@@ -0,0 +1,92 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/context.h"
+#include <stdio.h>
+#include <string.h>
+
+TfLiteIntArray* TfLiteIntArrayCreate(int size) {
+ TfLiteIntArray* ret =
+ (TfLiteIntArray*)malloc(sizeof(*ret) + sizeof(ret->data[0]) * size);
+ ret->size = size;
+ return ret;
+}
+
+void TfLiteIntArrayPrint(const char* s, TfLiteIntArray* a) {
+ printf("%s: length=%d [", s, a->size);
+ if (a->size) printf("%d", a->data[0]);
+ int i = 1;
+ for (; i < a->size; i++) {
+ printf(" %d", a->data[i]);
+ }
+ printf("]\n");
+}
+
+int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b) {
+ if (a == b) return 1;
+ if (a == NULL || b == NULL) return 0;
+ if (a->size != b->size) return 0;
+ int i = 0;
+ for (; i < a->size; i++)
+ if (a->data[i] != b->data[i]) return 0;
+ return 1;
+}
+
+TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src) {
+ if (!src) return NULL;
+ TfLiteIntArray* ret = TfLiteIntArrayCreate(src->size);
+ if (ret) {
+ memcpy(ret->data, src->data, src->size * sizeof(int));
+ }
+ return ret;
+}
+
+void TfLiteIntArrayFree(TfLiteIntArray* a) { free(a); }
+
+void TfLiteTensorFree(TfLiteTensor* t) {
+ if (t->allocation_type == kTfLiteDynamic && t->data.raw) {
+ free(t->data.raw);
+ }
+ if (t->dims) TfLiteIntArrayFree(t->dims);
+ t->data.raw = NULL;
+ t->dims = NULL;
+}
+
+void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
+ TfLiteQuantizationParams quantization, char* buffer,
+ size_t size, TfLiteAllocationType allocation_type,
+ const void* allocation, TfLiteTensor* tensor) {
+ TfLiteTensorFree(tensor);
+ tensor->type = type;
+ tensor->name = name;
+ tensor->dims = dims;
+ tensor->params = quantization;
+ tensor->data.raw = buffer;
+ tensor->bytes = size;
+ tensor->allocation_type = allocation_type;
+ tensor->allocation = allocation;
+}
+
+void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) {
+ if (tensor->allocation_type != kTfLiteDynamic) {
+ return;
+ }
+ if (!tensor->data.raw) {
+ tensor->data.raw = malloc(num_bytes);
+ } else if (num_bytes > tensor->bytes) {
+ tensor->data.raw = realloc(tensor->data.raw, num_bytes);
+ }
+ tensor->bytes = num_bytes;
+}
diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h
new file mode 100644
index 0000000000..41257a53b1
--- /dev/null
+++ b/tensorflow/contrib/lite/context.h
@@ -0,0 +1,298 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// This file defines a C API for implementing operations in tflite.
+// These operations can be defined using c++ but the interface between
+// the interpreter and the operations are C.
+//
+// Summary of abstractions
+// TF_LITE_ENSURE - Self-sufficient error checking
+// TfLiteStatus - Status reporting
+// TfLiteIntArray - stores tensor shapes (dims),
+// TfLiteContext - allows an op to access the tensors
+// TfLiteTensor - tensor (a multidimensional array)
+// TfLiteNode - a single node or operation
+// TfLiteRegistration - the implementation of a conceptual operation.
+//
+// Some abstractions in this file are created and managed by Interpreter.
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
+
+#include <stdint.h>
+#include <stdlib.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus;
+
+#define kOptionalTensor (-1)
+
+// Fixed size list of integers. Used for dimensions and inputs/outputs tensor
+// indices
+typedef struct {
+ int size;
+// gcc 6.1+ have a bug where flexible members aren't properly handled
+// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c
+#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \
+ __GNUC_MINOR__ >= 1
+ int data[0];
+#else
+ int data[];
+#endif
+} TfLiteIntArray;
+
+// Create a array of a given `size` (uninitialized entries).
+// This returns a pointer, that you must free using TfLiteIntArrayFree().
+TfLiteIntArray* TfLiteIntArrayCreate(int size);
+
+// Check if two tensors are equal. Returns 1 if they are equal, 0 otherwise.
+int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b);
+
+// Create a copy of an array passed as `src`.
+// You are expected to free memory with TfLiteIntArrayFree
+TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src);
+
+// Free memory of array `v`.
+void TfLiteIntArrayFree(TfLiteIntArray* v);
+
+// Since we must not depend on any libraries, define a minimal subset of
+// error macros while avoiding names that have pre-conceived meanings like
+// assert and check.
+
+// Check whether value is true, and if not return kTfLiteError from
+// the current function (and report the error string msg).
+#define TF_LITE_ENSURE_MSG(context, value, msg) \
+ do { \
+ if (!(value)) { \
+ (context)->ReportError((context), __FILE__ " " msg); \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+// Check whether the value `a` is true, and if not return kTfLiteError from
+// the current function, while also reporting the location of the error.
+#define TF_LITE_ENSURE(context, a) \
+ do { \
+ if (!(a)) { \
+ (context)->ReportError((context), "%s:%d %s was not true.", __FILE__, \
+ __LINE__, #a); \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+#define TF_LITE_ENSURE_STATUS(a) \
+ do { \
+ if ((a) != kTfLiteOk) { \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+// Check whether the value `a == b` is true, and if not return kTfLiteError from
+// the current function, while also reporting the location of the error.
+// `a` and `b` may be evaluated more than once, so no side effects or
+// extremely expensive computations should be done.
+#define TF_LITE_ENSURE_EQ(context, a, b) \
+ do { \
+ if ((a) != (b)) { \
+ (context)->ReportError((context), "%s:%d %s != %s (%d != %d)", __FILE__, \
+ __LINE__, #a, #b, (a), (b)); \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+#define TF_LITE_ENSURE_OK(context, status) \
+ do { \
+ if ((status) != kTfLiteOk) { \
+ return status; \
+ } \
+ } while (0)
+
+// Types supported by tensor
+typedef enum {
+ kTfLiteNoType = 0,
+ kTfLiteFloat32 = 1,
+ kTfLiteInt32 = 2,
+ kTfLiteUInt8 = 3,
+ kTfLiteInt64 = 4,
+ kTfLiteString = 5,
+} TfLiteType;
+
+// Parameters for asymmetric quantization. Quantized values can be converted
+// back to float using:
+// real_value = scale * (quantized_value - zero_point);
+typedef struct {
+ float scale;
+ int32_t zero_point;
+} TfLiteQuantizationParams;
+
+// A union of points that points to memory for a given tensor.
+typedef union {
+ int* i32;
+ float* f;
+ char* raw;
+ const char* raw_const;
+ uint8_t* uint8;
+} TfLitePtrUnion;
+
+// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
+// data (or data externally allocated). kTfLiteArenaRw is arena allocated
+// data. kTfLiteDynamic is for tensors that are allocated during evaluation.
+typedef enum {
+ kTfLiteMemNone = 0,
+ kTfLiteMmapRo,
+ kTfLiteArenaRw,
+ kTfLiteArenaRwPersistent,
+ kTfLiteDynamic,
+} TfLiteAllocationType;
+
+// An tensor in the interpreter system which is a wrapper around a buffer of
+// data including a dimensionality (or NULL if not currently defined).
+typedef struct {
+ // The data type specification for data stored in `data`. This affects
+ // what member of `data` union should be used.
+ TfLiteType type;
+ // A union of data pointers. The appropriate type should be used for a typed
+ // tensor based on `type`.
+ TfLitePtrUnion data;
+ // A pointer to a structure representing the dimensionality interpretation
+ // that the buffer should have. NOTE: the product of elements of `dims`
+ // and the element datatype size should be equal to `bytes` below.
+ TfLiteIntArray* dims;
+ // Quantization information.
+ TfLiteQuantizationParams params;
+ // How memory is mapped
+ // kTfLiteMmapRo: Memory mapped read only.
+ // i.e. weights
+ // kTfLiteArenaRw: Arena allocated read write memory
+ // (i.e. temporaries, outputs).
+ TfLiteAllocationType allocation_type;
+ // The number of bytes required to store the data of this Tensor. I.e.
+ // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
+ // type is kTfLiteFloat32 and dims = {3, 2} then
+ // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
+ size_t bytes;
+
+ // An opaque pointer to a tflite::MMapAllocation
+ const void* allocation;
+
+ // Null-terminated name of this tensor.
+ const char* name;
+} TfLiteTensor;
+
+// Free memory of tensor `t`;
+void TfLiteTensorFree(TfLiteTensor* t);
+
+// Set all of a tensor's fields (and free any previously allocated data).
+void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
+ TfLiteQuantizationParams quantization, char* buffer,
+ size_t size, TfLiteAllocationType allocation_type,
+ const void* allocation, TfLiteTensor* tensor);
+
+// Resize the allocated data of a (dynamic) tensor.
+void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
+
+typedef struct TfLiteContext {
+ // Number of tensors in the context.
+ int tensors_size;
+ // An tensor of tensors in the interpreter context (of length `tensors_size`)
+ TfLiteTensor* tensors;
+
+ // opaque full context ptr (an opaque c++ data structure)
+ void* impl_;
+
+ // Request memory pointer be resized. Updates dimensions on the tensor.
+ // NOTE: ResizeTensor takes ownership of newSize.
+ TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor,
+ TfLiteIntArray* new_size);
+ // Request that a error be reported with format string msg.
+ void (*ReportError)(struct TfLiteContext*, const char* msg, ...);
+
+ // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries. If
+ // non-null, the value pointed to by `first_new_tensor_index` will be set to
+ // the index of the first new tensor.
+ TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add,
+ int* first_new_tensor_index);
+
+ // TODO(ahentz): we should create a more general mechanism for this sort of
+ // library-global objects.
+ void* gemm_context;
+} TfLiteContext;
+
+// A structure representing an instance of a node.
+// This structure only exhibits the inputs, outputs and user defined data, not
+// other features like the type.
+typedef struct {
+ // Inputs to this node expressed as indices into the simulator's tensors.
+ TfLiteIntArray* inputs;
+
+ // Outputs to this node expressed as indices into the simulator's tensors.
+ TfLiteIntArray* outputs;
+
+ // Temporary tensors uses during the computations. This usually contains no
+ // tensors, but ops are allowed to change that if they need scratch space of
+ // any sort.
+ TfLiteIntArray* temporaries;
+
+ // Opaque data provided by the node implementer through `Registration.init`.
+ void* user_data;
+
+ // Opaque data provided to the node if the node is a builtin.
+ void* builtin_data;
+} TfLiteNode;
+
+typedef struct {
+ // Initializes the op from serialized data.
+ // If a built-in op:
+ // `buffer` is the op's params data (TfLiteLSTMParams*).
+ // `length` is zero.
+ // If custom op:
+ // `buffer` is the op's `custom_options`.
+ // `length` is the size of the buffer.
+ //
+ // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer
+ // or an instance of a struct).
+ //
+ // The returned pointer will be stored with the node in the `user_data` field,
+ // accessible within prepare and invoke functions below.
+ // NOTE: if the data is already in the desired format, simply implement this
+ // function to return `nullptr` and implement the free function to be a no-op.
+ void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
+
+ // The pointer `buffer` is the data previously returned by an init invocation.
+ void (*free)(TfLiteContext* context, void* buffer);
+
+ // prepare is called when the inputs this node depends on have been resized.
+ // context->ResizeTensor() can be called to request output tensors to be
+ // resized.
+ //
+ // Returns kTfLiteOk on success.
+ TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
+
+ // Execute the node (should read node->inputs and output to node->outputs).
+ // Returns kTfLiteOk on success.
+ TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
+
+ // Builtin codes. If this kernel refers to a builtin this is the code
+ // of the builtin. This is so we can do marshaling to other frameworks like
+ // NN API. Note, it is the responsibility of the registration binder to
+ // set this properly.
+ int32_t builtin_code;
+} TfLiteRegistration;
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
diff --git a/tensorflow/contrib/lite/context_test.cc b/tensorflow/contrib/lite/context_test.cc
new file mode 100644
index 0000000000..d0a104f43d
--- /dev/null
+++ b/tensorflow/contrib/lite/context_test.cc
@@ -0,0 +1,74 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/context.h"
+#include <gtest/gtest.h>
+
+namespace tflite {
+
+// NOTE: this tests only the TfLiteIntArray part of context.
+// most of context.h is provided in the context of using it with interpreter.h
+// and interpreter.cc, so interpreter_test.cc tests context structures more
+// thoroughly.
+
+TEST(IntArray, TestIntArrayCreate) {
+ TfLiteIntArray* a = TfLiteIntArrayCreate(0);
+ TfLiteIntArray* b = TfLiteIntArrayCreate(3);
+ TfLiteIntArrayFree(a);
+ TfLiteIntArrayFree(b);
+}
+
+TEST(IntArray, TestIntArrayCopy) {
+ TfLiteIntArray* a = TfLiteIntArrayCreate(2);
+ a->data[0] = 22;
+ a->data[1] = 24;
+ TfLiteIntArray* b = TfLiteIntArrayCopy(a);
+ ASSERT_NE(a, b);
+ ASSERT_EQ(a->size, b->size);
+ ASSERT_EQ(a->data[0], b->data[0]);
+ ASSERT_EQ(a->data[1], b->data[1]);
+ TfLiteIntArrayFree(a);
+ TfLiteIntArrayFree(b);
+}
+
+TEST(IntArray, TestIntArrayEqual) {
+ TfLiteIntArray* a = TfLiteIntArrayCreate(1);
+ a->data[0] = 1;
+ TfLiteIntArray* b = TfLiteIntArrayCreate(2);
+ b->data[0] = 5;
+ b->data[1] = 6;
+ TfLiteIntArray* c = TfLiteIntArrayCreate(2);
+ c->data[0] = 5;
+ c->data[1] = 6;
+ TfLiteIntArray* d = TfLiteIntArrayCreate(2);
+ d->data[0] = 6;
+ d->data[1] = 6;
+ ASSERT_FALSE(TfLiteIntArrayEqual(a, b));
+ ASSERT_TRUE(TfLiteIntArrayEqual(b, c));
+ ASSERT_TRUE(TfLiteIntArrayEqual(b, b));
+ ASSERT_FALSE(TfLiteIntArrayEqual(c, d));
+ TfLiteIntArrayFree(a);
+ TfLiteIntArrayFree(b);
+ TfLiteIntArrayFree(c);
+ TfLiteIntArrayFree(d);
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/error_reporter.cc b/tensorflow/contrib/lite/error_reporter.cc
new file mode 100644
index 0000000000..6ba5384a94
--- /dev/null
+++ b/tensorflow/contrib/lite/error_reporter.cc
@@ -0,0 +1,50 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include <cstdarg>
+#include <cstdio>
+
+namespace tflite {
+
+ErrorReporter::~ErrorReporter() {}
+
+int ErrorReporter::Report(const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ int code = Report(format, args);
+ va_end(args);
+ return code;
+}
+
+// TODO(aselle): Make the name of ReportError on context the same, so
+// we can use the ensure functions w/o a context and w/ a reporter.
+int ErrorReporter::ReportError(void*, const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ int code = Report(format, args);
+ va_end(args);
+ return code;
+}
+
+int StderrReporter::Report(const char* format, va_list args) {
+ return vfprintf(stderr, format, args);
+}
+
+ErrorReporter* DefaultErrorReporter() {
+ static StderrReporter* error_reporter = new StderrReporter;
+ return error_reporter;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/error_reporter.h b/tensorflow/contrib/lite/error_reporter.h
new file mode 100644
index 0000000000..637d456ce7
--- /dev/null
+++ b/tensorflow/contrib/lite/error_reporter.h
@@ -0,0 +1,54 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
+
+#include <cstdarg>
+#include "tensorflow/contrib/lite/context.h"
+
+namespace tflite {
+
+// A functor that reports error to supporting system. Invoked similar to
+// printf.
+//
+// Usage:
+// ErrorReporter foo;
+// foo.Report("test %d\n", 5);
+// or
+// va_list args;
+// foo.Report("test %d\n", args); // where args is va_list
+//
+// Sublclass ErrorReporter to provide another reporting destination.
+// For example, if you have a GUI program, you might redirect to a buffer
+// that drives a GUI error log box.
+class ErrorReporter {
+ public:
+ virtual ~ErrorReporter();
+ virtual int Report(const char* format, va_list args) = 0;
+ int Report(const char* format, ...);
+ int ReportError(void*, const char* format, ...);
+};
+
+// An error reporter that simplify writes the message to stderr.
+struct StderrReporter : public ErrorReporter {
+ int Report(const char* format, va_list args) override;
+};
+
+// Return the default error reporter (output to stderr).
+ErrorReporter* DefaultErrorReporter();
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
new file mode 100644
index 0000000000..954e236ac8
--- /dev/null
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -0,0 +1,567 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/interpreter.h"
+#include <cassert>
+#include <cstdarg>
+#include <cstdint>
+#include <cstring>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/kernels/gemm_support.h"
+#include "tensorflow/contrib/lite/nnapi_delegate.h"
+
+namespace {
+
+// Memory allocation tuning
+constexpr const int kDefaultArenaAlignment = 64;
+constexpr const int kDefaultTensorAlignment = 4;
+// std::vector preallocation tuning.
+constexpr const int kSlotsToReserve = 128;
+
+} // namespace
+
+namespace tflite {
+
+Interpreter::Interpreter(ErrorReporter* error_reporter)
+ : arena_(kDefaultArenaAlignment),
+ persistent_arena_(kDefaultArenaAlignment),
+ error_reporter_(error_reporter ? error_reporter
+ : DefaultErrorReporter()) {
+ context_.impl_ = static_cast<void*>(this);
+ context_.ResizeTensor = ResizeTensor;
+ context_.ReportError = ReportError;
+ context_.AddTensors = AddTensors;
+ context_.tensors = nullptr;
+ context_.tensors_size = 0;
+ context_.gemm_context = nullptr;
+ // Reserve some space for the tensors to avoid excessive resizing.
+ tensors_.reserve(kSlotsToReserve);
+ nodes_and_registration_.reserve(kSlotsToReserve);
+ next_allocate_node_id_ = 0;
+ UseNNAPI(false);
+}
+
+Interpreter::~Interpreter() {
+ for (auto& nodeAndReg : nodes_and_registration_) {
+ TfLiteNode& node = nodeAndReg.first;
+ TfLiteIntArrayFree(node.inputs);
+ TfLiteIntArrayFree(node.outputs);
+ TfLiteIntArrayFree(node.temporaries);
+ if (node.builtin_data) free(node.builtin_data);
+ OpFree(nodeAndReg.second, node.user_data);
+ node.builtin_data = nullptr;
+ }
+
+ for (int i = 0; i < context_.tensors_size; i++) {
+ TfLiteTensorFree(&context_.tensors[i]);
+ }
+}
+
+TfLiteStatus Interpreter::SetInputs(std::vector<int> inputs) {
+ TF_LITE_ENSURE_OK(&context_,
+ CheckTensorIndices("inputs", inputs.data(), inputs.size()));
+ inputs_ = std::move(inputs);
+ return kTfLiteOk;
+}
+
+TfLiteStatus Interpreter::SetOutputs(std::vector<int> outputs) {
+ TF_LITE_ENSURE_OK(
+ &context_, CheckTensorIndices("outputs", outputs.data(), outputs.size()));
+ outputs_ = std::move(outputs);
+ return kTfLiteOk;
+}
+
+TfLiteStatus Interpreter::CheckTensorIndices(const char* label,
+ const int* indices, int length) {
+ // Making sure kOptionalTensor is not re-defined to something other than -1.
+ static_assert(kOptionalTensor == -1, "kOptionalTensor should be defined -1");
+
+ for (int i = 0; i < length; i++) {
+ int index = indices[i];
+ if (index < kOptionalTensor || index >= context_.tensors_size) {
+ ReportError(&context_, "Invalid tensor index %d in %s\n", index, label);
+ consistent_ = false;
+ return kTfLiteError;
+ }
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims,
+ int dims_size, size_t* bytes) {
+ // TODO(aselle): Check for overflow here using overflow.h in TensorFlow
+ // MultiplyWithoutOverflow.
+ TF_LITE_ENSURE(&context_, bytes != nullptr);
+ size_t count = 1;
+ for (int k = 0; k < dims_size; k++) count *= dims[k];
+ switch (type) {
+ case kTfLiteFloat32:
+ *bytes = sizeof(float) * count;
+ break;
+ case kTfLiteInt32:
+ *bytes = sizeof(int32_t) * count;
+ break;
+ case kTfLiteUInt8:
+ *bytes = sizeof(uint8_t) * count;
+ break;
+ case kTfLiteInt64:
+ *bytes = sizeof(int64_t) * count;
+ break;
+ default:
+ ReportError(&context_,
+ "Only float32, int32, int64, uint8 supported currently.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Interpreter::AllocateTensorsWhoseSizesAreKnown() {
+ if (!consistent_) {
+ ReportError(&context_, "AllocateTensors() called on inconsistent model.");
+ return kTfLiteError;
+ }
+ if (next_allocate_node_id_ == nodes_and_registration_.size() && invokable_) {
+ return kTfLiteOk;
+ }
+ allocs_and_refcounts_.resize(context_.tensors_size);
+
+ int new_next_allocate_node_id = next_allocate_node_id_;
+ invokable_ = false;
+
+ // Allocate graph input nodes.
+ if (next_allocate_node_id_ == 0) {
+ for (int i = 0; i < inputs_.size(); ++i) {
+ int tensor_index = inputs_[i];
+ if (tensor_index == kOptionalTensor) {
+ continue;
+ }
+ TfLiteTensor& tensor = context_.tensors[tensor_index];
+ if (tensor.allocation_type == kTfLiteArenaRw) {
+ TF_LITE_ENSURE_OK(
+ &context_,
+ arena_.Allocate(&context_, kDefaultTensorAlignment, tensor.bytes,
+ &allocs_and_refcounts_[tensor_index].alloc));
+ }
+ }
+ // Add 1 to output tensors, so they will not get overwritten.
+ for (int i = 0; i < outputs_.size(); ++i) {
+ allocs_and_refcounts_[outputs_[i]].count++;
+ }
+ }
+
+ // Count references to node input tensors, and resize node-referenced tensors
+ // until we encounter a node that has a dynamic output tensor.
+ for (int k = next_allocate_node_id_; k < nodes_and_registration_.size();
+ k++) {
+ new_next_allocate_node_id++;
+ TfLiteNode& node = nodes_and_registration_[k].first;
+ const TfLiteRegistration& registration = nodes_and_registration_[k].second;
+ if (OpPrepare(registration, &node) == kTfLiteError) {
+ return kTfLiteError;
+ }
+
+ TfLiteIntArray* node_inputs = node.inputs;
+ for (int i = 0; i < node_inputs->size; ++i) {
+ int tensor_index = node_inputs->data[i];
+ if (tensor_index != kOptionalTensor) {
+ allocs_and_refcounts_[node_inputs->data[i]].count++;
+ }
+ }
+
+ // Discontinue if the node has dynamic outputs.
+ bool has_unallocated_dynamic_tensor = false;
+ TfLiteIntArray* node_outputs = node.outputs;
+ for (int i = 0; i < node_outputs->size; ++i) {
+ TfLiteTensor& tensor = context_.tensors[node_outputs->data[i]];
+ if (tensor.allocation_type == kTfLiteDynamic) {
+ has_unallocated_dynamic_tensor = true;
+ break;
+ }
+ }
+ if (has_unallocated_dynamic_tensor) {
+ break;
+ }
+ }
+
+ // Allocate graph persistent outputs, e.g. RNN cell states, etc.
+ for (int k = next_allocate_node_id_; k < new_next_allocate_node_id; k++) {
+ TfLiteNode& node = nodes_and_registration_[k].first;
+
+ // Go through output tensors and allocate the persistent ones first.
+ TfLiteIntArray* node_outputs = node.outputs;
+ for (int i = 0; i < node_outputs->size; ++i) {
+ int tensor_index = node_outputs->data[i];
+ TfLiteTensor& tensor = context_.tensors[tensor_index];
+ if (tensor.allocation_type == kTfLiteArenaRwPersistent) {
+ TF_LITE_ENSURE_OK(&context_,
+ persistent_arena_.Allocate(
+ &context_, kDefaultTensorAlignment, tensor.bytes,
+ &allocs_and_refcounts_[tensor_index].alloc));
+ }
+ }
+ }
+
+ // Go through the graph in execution order.
+ for (int k = next_allocate_node_id_; k < new_next_allocate_node_id; k++) {
+ TfLiteNode& node = nodes_and_registration_[k].first;
+
+ // First allocate output tensors.
+ TfLiteIntArray* node_outputs = node.outputs;
+ for (int i = 0; i < node_outputs->size; ++i) {
+ int tensor_index = node_outputs->data[i];
+ TfLiteTensor& tensor = context_.tensors[tensor_index];
+ if (tensor.allocation_type == kTfLiteArenaRw) {
+ TF_LITE_ENSURE_OK(
+ &context_,
+ arena_.Allocate(&context_, kDefaultTensorAlignment, tensor.bytes,
+ &allocs_and_refcounts_[tensor_index].alloc));
+ }
+ }
+ // Then the temporaries, in two passes. First allocate them all, them
+ // deallocate them.
+ TfLiteIntArray* node_temporaries = node.temporaries;
+ for (int i = 0; i < node_temporaries->size; ++i) {
+ int tensor_index = node_temporaries->data[i];
+ TfLiteTensor& tensor = context_.tensors[tensor_index];
+ if (tensor.allocation_type == kTfLiteArenaRw) {
+ TF_LITE_ENSURE_OK(
+ &context_,
+ arena_.Allocate(&context_, kDefaultTensorAlignment, tensor.bytes,
+ &allocs_and_refcounts_[tensor_index].alloc));
+ }
+ }
+ for (int i = 0; i < node_temporaries->size; ++i) {
+ int tensor_index = node_temporaries->data[i];
+ TfLiteTensor& tensor = context_.tensors[tensor_index];
+ allocs_and_refcounts_[tensor_index].count--;
+ if (tensor.allocation_type == kTfLiteArenaRw &&
+ allocs_and_refcounts_[tensor_index].count == 0) {
+ TF_LITE_ENSURE_OK(
+ &context_,
+ arena_.Deallocate(&context_,
+ allocs_and_refcounts_[tensor_index].alloc));
+ }
+ }
+
+ // Then process the node's inputs.
+ TfLiteIntArray* node_inputs = node.inputs;
+ for (int i = 0; i < node_inputs->size; ++i) {
+ int tensor_index = node_inputs->data[i];
+ if (tensor_index == kOptionalTensor) {
+ continue;
+ }
+ TfLiteTensor& tensor = context_.tensors[tensor_index];
+
+ // Decrease reference count and deallocate if not needed anymore.
+ allocs_and_refcounts_[tensor_index].count--;
+ if (tensor.allocation_type == kTfLiteArenaRw &&
+ allocs_and_refcounts_[tensor_index].count == 0) {
+ TF_LITE_ENSURE_OK(
+ &context_,
+ arena_.Deallocate(&context_,
+ allocs_and_refcounts_[tensor_index].alloc));
+ }
+ }
+ }
+
+ // Resize the buffer and commit the arena.
+ TF_LITE_ENSURE_OK(&context_, arena_.Commit(&context_));
+ TF_LITE_ENSURE_OK(&context_, persistent_arena_.Commit(&context_));
+
+ // Rewire the tensors to use the underlying arena buffer.
+ for (int i = 0; i < context_.tensors_size; ++i) {
+ TfLiteTensor& tensor = context_.tensors[i];
+ if (tensor.allocation_type == kTfLiteArenaRw) {
+ TF_LITE_ENSURE_OK(
+ &context_,
+ arena_.ResolveAlloc(&context_, allocs_and_refcounts_[i].alloc,
+ &tensor.data.raw));
+ }
+ if (tensor.allocation_type == kTfLiteArenaRwPersistent) {
+ TF_LITE_ENSURE_OK(
+ &context_,
+ persistent_arena_.ResolveAlloc(
+ &context_, allocs_and_refcounts_[i].alloc, &tensor.data.raw));
+ }
+ }
+
+ invokable_ = true;
+ next_allocate_node_id_ = new_next_allocate_node_id;
+ return kTfLiteOk;
+}
+
+namespace {
+TfLiteIntArray* convertVectorToTfLiteIntArray(const std::vector<int>& x) {
+ TfLiteIntArray* lite = TfLiteIntArrayCreate(x.size());
+ for (size_t i = 0; i < x.size(); i++) lite->data[i] = x[i];
+ return lite;
+}
+} // namespace
+
+TfLiteStatus Interpreter::AllocateTensors() {
+ next_allocate_node_id_ = 0;
+ TF_LITE_ENSURE_OK(&context_, arena_.Clear());
+ TF_LITE_ENSURE_OK(&context_, persistent_arena_.Clear());
+ allocs_and_refcounts_.clear();
+ return AllocateTensorsWhoseSizesAreKnown();
+}
+
+TfLiteStatus Interpreter::AddNodeWithParameters(
+ const std::vector<int>& inputs, const std::vector<int>& outputs,
+ const char* init_data, size_t init_data_size, void* builtin_data,
+ const TfLiteRegistration* registration, int* node_index) {
+ invokable_ = false;
+
+ std::unique_ptr<void, decltype(free)*> builtin_data_deleter(builtin_data,
+ free);
+
+ TF_LITE_ENSURE_OK(&context_, CheckTensorIndices("node inputs", inputs.data(),
+ inputs.size()));
+ TF_LITE_ENSURE_OK(
+ &context_,
+ CheckTensorIndices("node outputs", outputs.data(), outputs.size()));
+
+ if (node_index) *node_index = nodes_and_registration_.size();
+ nodes_and_registration_.resize(nodes_and_registration_.size() + 1);
+ auto& node_and_reg = nodes_and_registration_.back();
+ TfLiteNode& node = node_and_reg.first;
+ if (node.inputs) TfLiteIntArrayFree(node.inputs);
+ if (node.outputs) TfLiteIntArrayFree(node.outputs);
+ if (node.temporaries) TfLiteIntArrayFree(node.temporaries);
+
+ // NOTE, here we are not using move semantics yet, since our internal
+ // representation isn't std::vector, but in the future we would like to avoid
+ // copies, so we want the interface to take r-value references now.
+ node.inputs = convertVectorToTfLiteIntArray(inputs);
+ node.outputs = convertVectorToTfLiteIntArray(outputs);
+ node.temporaries = TfLiteIntArrayCreate(0);
+ if (init_data) {
+ node.user_data = OpInit(*registration, init_data, init_data_size);
+ } else {
+ node.user_data =
+ OpInit(*registration,
+ reinterpret_cast<const char*>(builtin_data_deleter.get()), 0);
+ }
+ node.builtin_data = builtin_data_deleter.release();
+ node_and_reg.second = *registration;
+ return kTfLiteOk;
+}
+
+TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index,
+ const std::vector<int>& dims) {
+ // TODO(aselle): All bounds checks can be implemented as one-sided bounds
+ // checks by casting to unsigned for efficiency. Profile before doing this.
+
+ TF_LITE_ENSURE(&context_,
+ tensor_index < context_.tensors_size && tensor_index >= 0);
+ invokable_ = false;
+ TfLiteIntArray* dims_lite = convertVectorToTfLiteIntArray(dims);
+ return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite);
+}
+
+TfLiteStatus Interpreter::Invoke() {
+ if (!consistent_) {
+ ReportError(&context_, "Invoke called on model that is not consistent.");
+ return kTfLiteError;
+ }
+ if (!invokable_) {
+ ReportError(&context_, "Invoke called on model that is not ready.");
+ return kTfLiteError;
+ }
+
+ TfLiteStatus status = kTfLiteOk;
+ if (nnapi_delegate_) {
+ if (AllocateTensorsWhoseSizesAreKnown() == kTfLiteError) {
+ return kTfLiteError;
+ }
+ if (next_allocate_node_id_ == nodes_and_registration_.size()) {
+ TF_LITE_ENSURE_OK(&context_, nnapi_delegate_->Invoke(this));
+ return kTfLiteOk;
+ } else {
+ // TODO(aselle): In the future, we would like this to be an
+ // automatic tflite CPU fallback.
+ ReportError(&context_,
+ "NNAPI was requested, but dependent sized tensors "
+ "being used.\n");
+ return kTfLiteError;
+ }
+ }
+
+ for (int i = 0; i < nodes_and_registration_.size(); i++) {
+ // Ensure we have allocated up to this node. The point of this is to
+ // allocate as much as possible before running any evaluation, but
+ // dynamic shapes can prevent this from being possible.
+ if (i >= next_allocate_node_id_) {
+ if (AllocateTensorsWhoseSizesAreKnown() == kTfLiteError) {
+ return kTfLiteError;
+ }
+ }
+ TfLiteNode& node = nodes_and_registration_[i].first;
+ const TfLiteRegistration& registration = nodes_and_registration_[i].second;
+ if (OpInvoke(registration, &node) == kTfLiteError) {
+ status = kTfLiteError;
+ }
+ }
+ return status;
+}
+
+TfLiteStatus Interpreter::ResizeTensor(TfLiteContext* context,
+ TfLiteTensor* tensor,
+ TfLiteIntArray* new_size) {
+ // Note here that context->impl_ is recovering the this pointer for an
+ // instance of Interpreter to call into the member function ResizeTensorImpl
+ // (this function is static).
+ return static_cast<Interpreter*>(context->impl_)
+ ->ResizeTensorImpl(tensor, new_size);
+}
+
+void Interpreter::ReportErrorImpl(const char* format, va_list args) {
+ error_reporter_->Report(format, args);
+}
+
+void Interpreter::ReportError(TfLiteContext* context, const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ auto* f = static_cast<Interpreter*>(context->impl_);
+ // Note here that context->impl_ is recovering the this pointer for an
+ // instance of Interpreter to call into the member function ReportErrorImpl
+ // (this function is static).
+ f->ReportErrorImpl(format, args);
+ va_end(args);
+}
+
+TfLiteStatus Interpreter::AddTensors(int tensors_to_add,
+ int* first_new_tensor_index) {
+ int base_index = tensors_.size();
+ if (first_new_tensor_index) *first_new_tensor_index = base_index;
+ tensors_.resize(tensors_.size() + tensors_to_add);
+ for (int i = base_index; i < tensors_.size(); i++) {
+ memset(&tensors_[i], 0, sizeof(tensors_[i]));
+ }
+ context_.tensors = tensors_.data();
+ context_.tensors_size = tensors_.size();
+ return kTfLiteOk;
+}
+
+TfLiteStatus Interpreter::AddTensors(TfLiteContext* context, int tensors_to_add,
+ int* first_new_tensor_index) {
+ // Note here that context->impl_ is recovering the this pointer for an
+ // instance of Interpreter to call into the member function AddTensors
+ // (this function is static).
+ return static_cast<Interpreter*>(context->impl_)
+ ->AddTensors(tensors_to_add, first_new_tensor_index);
+}
+
+TfLiteStatus Interpreter::SetTensorParametersReadOnly(
+ int tensor_index, TfLiteType type, const char* name,
+ const std::vector<int>& dims, TfLiteQuantizationParams quantization,
+ const char* buffer, size_t bytes, const Allocation* allocation) {
+ TF_LITE_ENSURE(&context_,
+ tensor_index < context_.tensors_size && tensor_index >= 0);
+ // For most tensors we know exactly how much memory is necessary so we can
+ // ensure the buffer is large enough. However, we need to skip string tensors
+ // because their sizes change with the contents of the individual strings.
+ if (type != kTfLiteString) {
+ size_t required_bytes;
+ TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(),
+ &required_bytes));
+ TF_LITE_ENSURE_EQ(&context_, required_bytes, bytes);
+ }
+ invokable_ = false;
+ TfLiteTensorReset(type, name, convertVectorToTfLiteIntArray(dims),
+ quantization, const_cast<char*>(buffer), bytes,
+ kTfLiteMmapRo, allocation, &context_.tensors[tensor_index]);
+ return kTfLiteOk;
+}
+
+// Set description of inputs/outputs/data/fptrs for node `node_index`.
+// This variant assumes an external buffer has been allocated of size
+// bytes. The lifetime of buffer must be ensured to be greater or equal
+// to Interpreter.
+TfLiteStatus Interpreter::SetTensorParametersReadWrite(
+ int tensor_index, TfLiteType type, const char* name,
+ const std::vector<int>& dims, TfLiteQuantizationParams quantization) {
+ invokable_ = false;
+ TF_LITE_ENSURE(&context_,
+ tensor_index < context_.tensors_size && tensor_index >= 0);
+ size_t required_bytes = 0;
+ if (type != kTfLiteString) {
+ // These types will be allocated in our arena so we need to record how
+ // many bytes we will need based on the dimensions. String tensors are
+ // allocated dynamically and we can't know ahead of time how much space
+ // they will require.
+ TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(),
+ &required_bytes));
+ }
+ TfLiteTensorReset(type, name, convertVectorToTfLiteIntArray(dims),
+ quantization,
+ /*buffer=*/nullptr, required_bytes,
+ type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw,
+ nullptr, &context_.tensors[tensor_index]);
+ return kTfLiteOk;
+}
+
+TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor,
+ TfLiteIntArray* new_size) {
+ // Note that in theory we could resize kTfLiteArenaRwPersistent tensors too.
+ if (tensor->allocation_type == kTfLiteArenaRw ||
+ tensor->allocation_type == kTfLiteDynamic) {
+ if (tensor->type != kTfLiteString) {
+ size_t bytesRequired;
+ TfLiteStatus status = BytesRequired(tensor->type, new_size->data,
+ new_size->size, &bytesRequired);
+ if (status != kTfLiteOk) {
+ TfLiteIntArrayFree(new_size);
+ return kTfLiteError;
+ }
+ tensor->bytes = bytesRequired;
+ }
+ if (tensor->dims) TfLiteIntArrayFree(tensor->dims);
+ tensor->dims = new_size;
+
+ if (tensor->allocation_type != kTfLiteDynamic) {
+ tensor->data.raw = nullptr;
+ }
+ } else {
+ // kTfLiteMmapRo tensors are stored in the flatbuffer and are therefore
+ // of fixed size.
+ TfLiteIntArrayFree(new_size);
+ ReportError(&context_, "Attempting to resize a fixed-size tensor.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+void Interpreter::UseNNAPI(bool enable) {
+ // TODO(aselle): This is a workaround for finding if NNAPI exists.
+ // We also need to make sure getLibraryHandle() is renamed to be NNAPI
+ // prefixed.
+ if (!NNAPIExists()) enable = false;
+ if (!enable) {
+ nnapi_delegate_.reset();
+ } else if (!nnapi_delegate_) {
+ nnapi_delegate_.reset(new NNAPIDelegate);
+ }
+}
+
+void Interpreter::SetNumThreads(int num_threads) {
+ // TODO(ahentz): this forces us to link against gemmlowp even when the ops
+ // don't use it. We should implement some dynamic mechanism for this sort of
+ // library-specific initialization.
+ tflite::gemm_support::SetMaxNumThreads(&context_, num_threads);
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
new file mode 100644
index 0000000000..8bf60e91f7
--- /dev/null
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -0,0 +1,376 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Main abstraction controlling the tflite interpreter.
+// See context.h for the API for defining operations (TfLiteRegistration).
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_
+
+#include <cstdio>
+#include <cstdlib>
+#include <vector>
+#include "tensorflow/contrib/lite/allocation.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/simple_memory_arena.h"
+#include "tensorflow/core/platform/platform.h"
+
+namespace tflite {
+
+// Map statically from a c++ type to a TfLiteType (used below for safe casts).
+template <class T>
+constexpr TfLiteType typeToTfLiteType() {
+ return kTfLiteNoType;
+}
+template <>
+constexpr TfLiteType typeToTfLiteType<int>() {
+ return kTfLiteInt32;
+}
+template <>
+constexpr TfLiteType typeToTfLiteType<int64_t>() {
+ return kTfLiteInt64;
+}
+template <>
+constexpr TfLiteType typeToTfLiteType<float>() {
+ return kTfLiteFloat32;
+}
+template <>
+constexpr TfLiteType typeToTfLiteType<unsigned char>() {
+ return kTfLiteUInt8;
+}
+
+struct ArenaAllocRefCount {
+ ArenaAllocRefCount() : alloc(), count(0) {}
+
+ ArenaAlloc alloc;
+ int count;
+};
+
+// Forward declare since NNAPIDelegate uses Interpreter.
+class NNAPIDelegate;
+
+// An interpreter for a graph of nodes that input and output from tensors.
+// Each node of the graph processes a set of input tensors and produces a
+// set of output Tensors. All inputs/output tensors are referenced by index.
+//
+// Usage:
+//
+// -- Create basic model
+// Interpreter foo(2, 1);
+// foo.SetTensorParametersReadWrite(0, ...);
+// foo.SetTensorParametersReadOnly(1, ...);
+// foo.SetNodeParameters(0, ...)
+//
+// -- Resize input array to 1 length.
+// foo.ResizeInputTensor(0, 1);
+// foo.AllocateTensors();
+// -- Install array data
+// foo.typed_tensor<float>(0)[0] = 3;
+// foo.Invoke();
+// foo.typed_tensor<float>(0)[0] = 4;
+// foo.Invoke();
+// -- Resize input array and set data.
+// foo.ResizeInputTensor(0, 2);
+// foo.AllocateTensors();
+// foo.typed_tensor<float>(0)[0] = 4;
+// foo.typed_tensor<float>(0)[1] = 8;
+// foo.Invoke();
+//
+
+class Interpreter {
+ public:
+ // Instantiate an interpreter. All errors associated with reading and
+ // processing this model will be forwarded to the error_reporter object.
+ //
+ // Note, if error_reporter is nullptr, then a default StderrReporter is
+ // used.
+ explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter());
+
+ ~Interpreter();
+
+ Interpreter(const Interpreter&) = delete;
+ Interpreter& operator=(const Interpreter&) = delete;
+
+ // Functions to build interpreter
+
+ // Provide a list of tensor indexes that are inputs to the model.
+ // Each index is bound check and this modifies the consistent_ flag of the
+ // interpreter.
+ TfLiteStatus SetInputs(std::vector<int> inputs);
+
+ // Provide a list of tensor indexes that are outputs to the model
+ // Each index is bound check and this modifies the consistent_ flag of the
+ // interpreter.
+ TfLiteStatus SetOutputs(std::vector<int> outputs);
+
+ // Adds a node with the given parameters and returns the index of the new
+ // node in `node_index` (optionally). Interpreter will take ownership of
+ // `builtin_data` and destroy it with `delete`. Ownership of 'init_data'
+ // remains with the caller.
+ TfLiteStatus AddNodeWithParameters(const std::vector<int>& inputs,
+ const std::vector<int>& outputs,
+ const char* init_data,
+ size_t init_data_size, void* builtin_data,
+ const TfLiteRegistration* registration,
+ int* node_index = nullptr);
+
+ // Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries.
+ // The value pointed to by `first_new_tensor_index` will be set to the
+ // index of the first new tensor if `first_new_tensor_index` is non-null.
+ TfLiteStatus AddTensors(int tensors_to_add,
+ int* first_new_tensor_index = nullptr);
+
+ // Set description of inputs/outputs/data/fptrs for node `node_index`.
+ // This variant assumes an external buffer has been allocated of size
+ // bytes. The lifetime of buffer must be ensured to be greater or equal
+ // to Interpreter.
+ TfLiteStatus SetTensorParametersReadOnly(
+ int tensor_index, TfLiteType type, const char* name,
+ const std::vector<int>& dims, TfLiteQuantizationParams quantization,
+ const char* buffer, size_t bytes, const Allocation* allocation = nullptr);
+
+ // Set description of inputs/outputs/data/fptrs for node `node_index`.
+ // This variant assumes an external buffer has been allocated of size
+ // bytes. The lifetime of buffer must be ensured to be greater or equal
+ // to Interpreter.
+ TfLiteStatus SetTensorParametersReadWrite(
+ int tensor_index, TfLiteType type, const char* name,
+ const std::vector<int>& dims, TfLiteQuantizationParams quantization);
+
+ // Functions to access tensor data
+
+ // Read only access to list of inputs.
+ const std::vector<int>& inputs() const { return inputs_; }
+
+ // Return the name of a given input. The given index must be between 0 and
+ // inputs().size().
+ const char* GetInputName(int index) const {
+ return context_.tensors[inputs_[index]].name;
+ }
+
+ // Read only access to list of outputs.
+ const std::vector<int>& outputs() const { return outputs_; }
+
+ // Return the name of a given output. The given index must be between 0 and
+ // outputs().size().
+ const char* GetOutputName(int index) const {
+ return context_.tensors[outputs_[index]].name;
+ }
+
+ // Return the number of tensors in the model.
+ int tensors_size() const { return context_.tensors_size; }
+
+ // Return the number of ops in the model.
+ int nodes_size() const { return nodes_and_registration_.size(); }
+
+ // Get a tensor data structure.
+ // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
+ // read/write access to structure
+ TfLiteTensor* tensor(int tensor_index) {
+ if (tensor_index >= context_.tensors_size || tensor_index < 0)
+ return nullptr;
+ return &context_.tensors[tensor_index];
+ }
+
+ // Get a pointer to an operation and registration data structure if in bounds.
+ // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
+ // read/write access to structure
+ const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
+ int node_index) {
+ if (node_index >= nodes_and_registration_.size() || node_index < 0)
+ return nullptr;
+ return &nodes_and_registration_[node_index];
+ }
+
+ // Perform a checked cast to the appropriate tensor type.
+ template <class T>
+ T* typed_tensor(int tensor_index) {
+ if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
+ if (tensor_ptr->type == typeToTfLiteType<T>()) {
+ return reinterpret_cast<T*>(tensor_ptr->data.raw);
+ }
+ }
+ return nullptr;
+ }
+
+ // Return a pointer into the data of a given input tensor. The given index
+ // must be between 0 and inputs().size().
+ template <class T>
+ T* typed_input_tensor(int index) {
+ return typed_tensor<T>(inputs_[index]);
+ }
+
+ // Return a pointer into the data of a given output tensor. The given index
+ // must be between 0 and outputs().size().
+ template <class T>
+ T* typed_output_tensor(int index) {
+ return typed_tensor<T>(outputs_[index]);
+ }
+
+ // Change the dimensionality of a given tensor. Note, this is only acceptable
+ // for tensor indices that are inputs.
+ // Returns status of failure or success.
+ // TODO(aselle): Consider implementing ArraySlice equivalent to make this
+ // more adept at accepting data without an extra copy. Use absl::ArraySlice
+ // if our partners determine that dependency is acceptable.
+ TfLiteStatus ResizeInputTensor(int tensor_index,
+ const std::vector<int>& dims);
+
+ // Update allocations for all tensors. This will redim dependent tensors using
+ // the input tensor dimensionality as given. This is relatively expensive.
+ // If you know that your sizes are not changing, you need not call this.
+
+ // Returns status of success or failure.
+ // TODO(aselle): Madde
+ TfLiteStatus AllocateTensors();
+
+ // Invoke the interpreter (run the whole graph in dependency order).
+ //
+ // NOTE: It is possible that the interpreter is not in a ready state
+ // to evaluate (i.e. if a ResizeTensor() has been performed without an
+ // AllocateTensors().
+ // Returns status of success or failure.
+ TfLiteStatus Invoke();
+
+ // Enable or disable the NN API (true to enable)
+ void UseNNAPI(bool enable);
+
+ // Set the number of threads available to the interpreter.
+ void SetNumThreads(int num_threads);
+
+ private:
+ // Give 'op_reg' a chance to initialize itself using the contents of
+ // 'buffer'.
+ void* OpInit(const TfLiteRegistration& op_reg, const char* buffer,
+ size_t length) {
+ if (op_reg.init == nullptr) return nullptr;
+ return op_reg.init(&context_, buffer, length);
+ }
+
+ // Let 'op_reg' release any memory it might have allocated via 'OpInit'.
+ void OpFree(const TfLiteRegistration& op_reg, void* buffer) {
+ if (op_reg.free == nullptr) return;
+ if (buffer) {
+ op_reg.free(&context_, buffer);
+ }
+ }
+
+ // Prepare the given 'node' for execution.
+ TfLiteStatus OpPrepare(const TfLiteRegistration& op_reg, TfLiteNode* node) {
+ if (op_reg.prepare == nullptr) return kTfLiteOk;
+ return op_reg.prepare(&context_, node);
+ }
+
+ // Invoke the operator represented by 'node'.
+ TfLiteStatus OpInvoke(const TfLiteRegistration& op_reg, TfLiteNode* node) {
+ if (op_reg.invoke == nullptr) return kTfLiteError;
+ return op_reg.invoke(&context_, node);
+ }
+
+ // Allocate tensors whose sizes are known in order of nodes. Discontinue when
+ // we encounter a node that has a dynamic output tensor.
+ TfLiteStatus AllocateTensorsWhoseSizesAreKnown();
+
+ // Tensors needed by the interpreter. Use `AddTensors` to add more blank
+ // tensor entries. Note, `tensors_.data()` needs to be synchronized to the
+ // `context_` whenever this std::vector is reallocated. Currently this
+ // only happens in `AddTensors()`.
+ std::vector<TfLiteTensor> tensors_;
+
+ // Check if an array of tensor indices are valid with respect to the Tensor
+ // array.
+ // NOTE: this changes consistent_ to be false if indices are out of bounds.
+ TfLiteStatus CheckTensorIndices(const char* label, const int* indices,
+ int length);
+
+ // Compute the number of bytes required to represent a tensor with dimensions
+ // specified by the array dims (of length dims_size). Returns the status code
+ // and bytes.
+ TfLiteStatus BytesRequired(TfLiteType type, const int* dims, int dims_size,
+ size_t* bytes);
+
+ // Request an tensor be resized implementation.
+ TfLiteStatus ResizeTensorImpl(TfLiteTensor* tensor, TfLiteIntArray* new_size);
+
+ // Report a detailed error string (will be printed to stderr).
+ // TODO(aselle): allow user of class to provide alternative destinations.
+ void ReportErrorImpl(const char* format, va_list args);
+
+ // Entry point for C node plugin API to request an tensor be resized.
+ static TfLiteStatus ResizeTensor(TfLiteContext* context, TfLiteTensor* tensor,
+ TfLiteIntArray* new_size);
+ // Entry point for C node plugin API to report an error.
+ static void ReportError(TfLiteContext* context, const char* format, ...);
+
+ // Entry point for C node plugin API to add new tensors.
+ static TfLiteStatus AddTensors(TfLiteContext* context, int tensors_to_add,
+ int* first_new_tensor_index);
+
+ // A pure C data structure used to communicate with the pure C plugin
+ // interface. To avoid copying tensor metadata, this is also the definitive
+ // structure to store tensors.
+ TfLiteContext context_;
+
+ // Node inputs/outputs are stored in TfLiteNode and TfLiteRegistration stores
+ // function pointers to actual implementation.
+ std::vector<std::pair<TfLiteNode, TfLiteRegistration>>
+ nodes_and_registration_;
+
+ // Raw memory buffer that is allocated for all temporary and graph outputs.
+ // that are declared kTfLiteArenaRw.
+ SimpleMemoryArena arena_;
+
+ // Raw memory buffer that is allocated for persistent tensors that are
+ // declared as kTfLiteArenaRwPersistent.
+ SimpleMemoryArena persistent_arena_;
+
+ // Stores allocation and reference counts of all tensors.
+ std::vector<ArenaAllocRefCount> allocs_and_refcounts_;
+
+ // Whether the model is consistent. That is to say if the inputs and outputs
+ // of every node and the global inputs and outputs are valid indexes into
+ // the tensor array.
+ bool consistent_ = true;
+
+ // Whether the model is safe to invoke (if any errors occurred this
+ // will be false).
+ bool invokable_ = false;
+
+ // Array of indices representing the tensors that are inputs to the
+ // interpreter.
+ std::vector<int> inputs_;
+
+ // Array of indices representing the tensors that are outputs to the
+ // interpreter.
+ std::vector<int> outputs_;
+
+ // The error reporter delegate that tflite will forward queries errors to.
+ ErrorReporter* error_reporter_;
+
+ // Next node to allocate output tensors.
+ // During Invoke(), Interpreter will allocate input tensors first, which are
+ // known to be fixed size. Then it will allocate outputs from nodes as many
+ // as possible. When there is a node that produces dynamic sized tensor.
+ // Intepreter will stop allocating tensors, set the value of next allocate
+ // node id, and execute the node to generate the output tensor before continue
+ // to allocate successors. This process repeats until all nodes are executed.
+ // NOTE: this relies on the order of nodes that is in topological order.
+ int next_allocate_node_id_;
+
+ // Whether to delegate to NN API
+ std::unique_ptr<NNAPIDelegate> nnapi_delegate_;
+};
+
+} // namespace tflite
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
new file mode 100644
index 0000000000..edff210943
--- /dev/null
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -0,0 +1,526 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/interpreter.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace tflite {
+namespace {
+
+// Make an interpreter that has no tensors and no nodes
+TEST(BasicInterpreter, ZeroInterpreter) {
+ Interpreter interpreter;
+ interpreter.SetInputs({});
+ interpreter.SetOutputs({});
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
+}
+
+// Test various error conditions.
+TEST(BasicInterpreter, InvokeInvalidModel) {
+ Interpreter interpreter;
+ ASSERT_NE(interpreter.Invoke(), kTfLiteOk);
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
+}
+
+// Test size accesser functions.
+TEST(BasicInterpreter, TestSizeFunctions) {
+ Interpreter interpreter;
+ int base_index;
+ ASSERT_EQ(interpreter.nodes_size(), 0);
+ ASSERT_EQ(interpreter.tensors_size(), 0);
+ ASSERT_EQ(interpreter.AddTensors(2, &base_index), kTfLiteOk);
+ ASSERT_EQ(interpreter.tensors_size(), 2);
+ ASSERT_EQ(base_index, 0);
+ ASSERT_EQ(interpreter.AddTensors(3, &base_index), kTfLiteOk);
+ ASSERT_EQ(interpreter.tensors_size(), 5);
+ ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
+ ASSERT_EQ(interpreter.tensors_size(), 6);
+ ASSERT_EQ(base_index, 2);
+}
+
+// Test if invalid indices make a model inconsistent (and conversely if
+// valid indices keep a model consistent).
+TEST(BasicInterpreter, InconsistentModel) {
+ // Invalid inputs
+ {
+ Interpreter interpreter;
+ ASSERT_NE(interpreter.SetInputs({5}), kTfLiteOk);
+ ASSERT_NE(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_NE(interpreter.Invoke(), kTfLiteOk);
+ ASSERT_EQ(interpreter.inputs(), std::vector<int>());
+ }
+ // Invalid outputs
+ {
+ Interpreter interpreter;
+ ASSERT_NE(interpreter.SetOutputs({5}), kTfLiteOk);
+ ASSERT_NE(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_NE(interpreter.Invoke(), kTfLiteOk);
+ ASSERT_EQ(interpreter.outputs(), std::vector<int>());
+ }
+ // Invalid node inputs
+ {
+ Interpreter interpreter;
+ TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr};
+ ASSERT_NE(interpreter.AddNodeWithParameters({3}, {0}, nullptr, 0, nullptr,
+ &registration),
+ kTfLiteOk);
+ ASSERT_NE(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_NE(interpreter.Invoke(), kTfLiteOk);
+ }
+ // Valid inputs and outputs and a node with valid inputs and outputs
+ {
+ Interpreter interpreter;
+ ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
+ TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr};
+ ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk);
+ ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr,
+ &registration),
+ kTfLiteOk);
+ }
+}
+
+// Make an interpreter that has one tensor but no ops
+TEST(BasicInterpreter, CheckAllocate) {
+ struct {
+ TfLiteType type;
+ size_t size;
+ } cases[] = {
+ {kTfLiteFloat32, sizeof(float)},
+ {kTfLiteInt32, sizeof(int32_t)},
+ {kTfLiteUInt8, sizeof(uint8_t)},
+ {kTfLiteInt64, sizeof(int64_t)},
+ };
+
+ for (auto test : cases) {
+ Interpreter interpreter;
+ ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
+ interpreter.SetInputs({0, 1});
+ interpreter.SetOutputs({});
+ TfLiteQuantizationParams quant;
+
+ interpreter.SetTensorParametersReadWrite(0, test.type, "", {3}, quant);
+ interpreter.SetTensorParametersReadWrite(1, test.type, "", {4}, quant);
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_EQ(interpreter.tensor(0)->bytes, 3 * test.size);
+ ASSERT_NE(interpreter.tensor(0)->data.raw, nullptr);
+ ASSERT_EQ(interpreter.tensor(1)->bytes, 4 * test.size);
+ ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr);
+ }
+}
+
+TEST(BasicInterpreter, CheckResize) {
+ const float floats[] = {-3., -4.};
+ const int32_t int32s[] = {-3, -4};
+ const uint8_t uint8s[] = {3, 4};
+ const int64_t int64s[] = {6, -7};
+
+ struct {
+ TfLiteType type;
+ size_t size;
+ const char* array;
+ } cases[] = {
+ {kTfLiteFloat32, sizeof(float), reinterpret_cast<const char*>(floats)},
+ {kTfLiteInt32, sizeof(int32_t), reinterpret_cast<const char*>(int32s)},
+ {kTfLiteUInt8, sizeof(uint8_t), reinterpret_cast<const char*>(uint8s)},
+ {kTfLiteInt64, sizeof(int64_t), reinterpret_cast<const char*>(int64s)},
+ };
+
+ for (auto test : cases) {
+ Interpreter interpreter;
+
+ ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
+ interpreter.SetInputs({0, 1});
+ interpreter.SetOutputs({});
+ TfLiteQuantizationParams quant;
+
+ ASSERT_EQ(
+ interpreter.SetTensorParametersReadWrite(0, test.type, "", {3}, quant),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.SetTensorParametersReadOnly(
+ 1, test.type, "", {2}, quant, test.array, 2 * test.size),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_EQ(interpreter.ResizeInputTensor(0, {1, 2}), kTfLiteOk);
+ // Resizing a mmapped tensor is not allowed and should produce error.
+ ASSERT_NE(interpreter.ResizeInputTensor(1, {3}), kTfLiteOk);
+ // Set the tensor to be mmapped but with a buffer size that is insufficient
+ // to match the dimensionality.
+ ASSERT_NE(interpreter.SetTensorParametersReadOnly(
+ 1, test.type, "", {2}, quant, test.array, 1 * test.size),
+ kTfLiteOk);
+ // Allocating should work since we should have our last correct array
+ // values in place.
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ }
+}
+
+TEST(BasicInterpreter, CheckAlignment) {
+ struct {
+ TfLiteType type;
+ } cases[] = {
+ {kTfLiteFloat32},
+ {kTfLiteInt32},
+ {kTfLiteUInt8},
+ {kTfLiteInt64},
+ };
+
+ for (auto test : cases) {
+ Interpreter interpreter;
+
+ ASSERT_EQ(interpreter.AddTensors(4), kTfLiteOk);
+
+ for (int i = 0; i < 4; i++) {
+ TfLiteQuantizationParams quant;
+ interpreter.SetTensorParametersReadWrite(i, test.type, "", {2 * i + 1},
+ quant);
+ }
+ interpreter.AllocateTensors();
+ for (int i = 0; i < 4; i++) {
+ const TfLiteTensor& tensor = *interpreter.tensor(i);
+ ASSERT_EQ(reinterpret_cast<intptr_t>(tensor.data.raw) % 4, 0);
+ }
+ }
+}
+
+TEST(BasicInterpreter, CheckArenaAllocation) {
+ Interpreter interpreter;
+ ASSERT_EQ(interpreter.AddTensors(10), kTfLiteOk);
+
+ TfLiteQuantizationParams quant;
+ TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
+
+ std::vector<int> sizes{2048, 4096, 1023, 2047, 1021,
+ 2047, 1023, 2046, 1021, 2048};
+ for (int i = 0; i < sizes.size(); ++i) {
+ interpreter.SetTensorParametersReadWrite(i, kTfLiteUInt8, "", {sizes[i]},
+ quant);
+ }
+ interpreter.SetInputs({0, 1});
+ interpreter.SetOutputs({9, 4});
+ interpreter.AddNodeWithParameters({0, 1}, {2, 3}, nullptr, 0, nullptr, &reg);
+ interpreter.AddNodeWithParameters({2, 1}, {4, 5}, nullptr, 0, nullptr, &reg);
+ interpreter.AddNodeWithParameters({4, 3}, {6, 7}, nullptr, 0, nullptr, &reg);
+ interpreter.AddNodeWithParameters({6, 5}, {8}, nullptr, 0, nullptr, &reg);
+ interpreter.AddNodeWithParameters({8, 7}, {9}, nullptr, 0, nullptr, &reg);
+
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+
+ ASSERT_EQ(interpreter.tensor(0)->data.raw, interpreter.tensor(4)->data.raw);
+ ASSERT_EQ(interpreter.tensor(1)->data.raw, interpreter.tensor(7)->data.raw);
+
+ ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(1)->data.raw);
+ ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(1)->data.raw);
+ ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(1)->data.raw);
+
+ ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(3)->data.raw);
+ ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(3)->data.raw);
+ ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(3)->data.raw);
+ ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(3)->data.raw);
+ ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(3)->data.raw);
+ ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(3)->data.raw);
+ ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(3)->data.raw);
+ ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(3)->data.raw);
+
+ ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(5)->data.raw);
+ ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(5)->data.raw);
+ ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(5)->data.raw);
+ ASSERT_LT(interpreter.tensor(3)->data.raw, interpreter.tensor(5)->data.raw);
+ ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(5)->data.raw);
+ ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(5)->data.raw);
+ ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(5)->data.raw);
+ ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(5)->data.raw);
+ ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(5)->data.raw);
+}
+
+TEST(BasicInterpreter, BufferAccess) {
+ Interpreter interpreter;
+ ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
+
+ ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
+ 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ // Verify we get a valid pointer.r
+ ASSERT_NE(interpreter.typed_tensor<float>(0), nullptr);
+ // Verify incorrect pointer will not returned.
+ ASSERT_EQ(interpreter.typed_tensor<int>(0), nullptr);
+ // Verify that raw c interface ptr matches safe interface.
+ ASSERT_EQ(interpreter.typed_tensor<float>(0), interpreter.tensor(0)->data.f);
+}
+
+TEST(BasicInterpreter, NoOpInterpreter) {
+ Interpreter interpreter;
+ ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk);
+
+ ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
+ 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()),
+ kTfLiteOk);
+
+ ASSERT_EQ(interpreter.ResizeInputTensor(interpreter.inputs()[0], {1, 2, 3}),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
+}
+
+TEST(BasicInterpreter, OneOpInterpreter) {
+ Interpreter interpreter;
+ ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetOutputs({1}), kTfLiteOk);
+
+ TfLiteQuantizationParams quantized;
+ ASSERT_EQ(interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "in1",
+ {3}, quantized),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteFloat32, "out0",
+ {3}, quantized),
+ kTfLiteOk);
+
+ ASSERT_EQ(interpreter.GetInputName(0), "in1");
+ ASSERT_EQ(interpreter.GetOutputName(0), "out0");
+
+ TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
+ reg.init = [](TfLiteContext* context, const char*, size_t) -> void* {
+ auto* first_new_tensor = new int;
+ context->AddTensors(context, 2, first_new_tensor);
+ return first_new_tensor;
+ };
+ reg.free = [](TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<int*>(buffer);
+ };
+ reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
+ auto* first_new_tensor = reinterpret_cast<int*>(node->user_data);
+
+ TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* tensor1 = &context->tensors[node->outputs->data[0]];
+
+ TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims);
+ TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, tensor1, newSize));
+
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(2);
+ for (int i = 0; i < 2; ++i) {
+ node->temporaries->data[i] = *(first_new_tensor) + i;
+ }
+
+ auto setup_temporary = [&](int id) {
+ TfLiteTensor* tmp = &context->tensors[id];
+ tmp->type = kTfLiteFloat32;
+ tmp->allocation_type = kTfLiteArenaRw;
+ return context->ResizeTensor(context, tmp,
+ TfLiteIntArrayCopy(tensor0->dims));
+ };
+ TF_LITE_ENSURE_STATUS(setup_temporary(node->temporaries->data[0]));
+ TF_LITE_ENSURE_STATUS(setup_temporary(node->temporaries->data[1]));
+
+ return kTfLiteOk;
+ };
+ reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]];
+
+ auto populate = [&](int id) {
+ TfLiteTensor* t = &context->tensors[id];
+ int num = a0->dims->data[0];
+ for (int i = 0; i < num; i++) {
+ t->data.f[i] = a0->data.f[i];
+ }
+ };
+
+ populate(node->outputs->data[0]);
+ populate(node->temporaries->data[0]);
+ populate(node->temporaries->data[1]);
+ return kTfLiteOk;
+ };
+ ASSERT_EQ(
+ interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, &reg),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.ResizeInputTensor(0, {3}), kTfLiteOk);
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+
+ ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
+}
+
+// Forcefully divides tensor allocation in three steps: one before invocation
+// and two more at invocation time. This happens because we use string tensors
+// and their sizes can't be determined until invocation time.
+TEST(BasicInterpreter, ThreeStepAllocate) {
+ Interpreter interpreter;
+ ASSERT_EQ(interpreter.AddTensors(5), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetOutputs({4}), kTfLiteOk);
+
+ TfLiteQuantizationParams quantized;
+ char data[] = {1, 0, 0, 0, 12, 0, 0, 0, 15, 0, 0, 0, 'A', 'B', 'C'};
+ // Read only string tensor.
+ ASSERT_EQ(interpreter.SetTensorParametersReadOnly(0, kTfLiteString, "", {1},
+ quantized, data, 15),
+ kTfLiteOk);
+ // Read-write string tensor.
+ ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteString, "", {1},
+ quantized),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.SetTensorParametersReadWrite(2, kTfLiteInt32, "", {1},
+ quantized),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.SetTensorParametersReadWrite(3, kTfLiteString, "", {1},
+ quantized),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.SetTensorParametersReadWrite(4, kTfLiteInt32, "", {1},
+ quantized),
+ kTfLiteOk);
+
+ // String-in String-out node.
+ TfLiteRegistration reg_copy = {nullptr, nullptr, nullptr, nullptr};
+ reg_copy.invoke = [](TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]];
+ DynamicBuffer buf;
+ StringRef str_ref = GetString(a0, 0);
+ buf.AddString(str_ref);
+ buf.WriteToTensor(a1);
+ return kTfLiteOk;
+ };
+
+ // String-in Int-out node.
+ TfLiteRegistration reg_len = {nullptr, nullptr, nullptr, nullptr};
+ reg_len.prepare = [](TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+ TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1);
+ outputSize->data[0] = 1;
+ return context->ResizeTensor(context, output, outputSize);
+ };
+ reg_len.invoke = [](TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]];
+ a1->data.i32[0] = a0->bytes;
+ return kTfLiteOk;
+ };
+
+ ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr,
+ &reg_copy),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.AddNodeWithParameters({1}, {2}, nullptr, 0, nullptr,
+ &reg_len),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {3}, nullptr, 0, nullptr,
+ &reg_copy),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.AddNodeWithParameters({3}, {4}, nullptr, 0, nullptr,
+ &reg_len),
+ kTfLiteOk);
+
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
+
+ ASSERT_EQ(interpreter.tensor(0)->bytes, 15);
+ ASSERT_NE(interpreter.tensor(0)->data.raw, nullptr);
+ ASSERT_EQ(interpreter.tensor(1)->bytes, 15);
+ ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr);
+ ASSERT_EQ(interpreter.tensor(3)->bytes, 15);
+ ASSERT_NE(interpreter.tensor(4)->data.raw, nullptr);
+ ASSERT_EQ(interpreter.tensor(2)->bytes, 4);
+ ASSERT_EQ(interpreter.tensor(2)->data.i32[0], 15);
+ ASSERT_EQ(interpreter.tensor(4)->bytes, 4);
+ ASSERT_EQ(interpreter.tensor(4)->data.i32[0], 15);
+}
+
+TEST(BasicInterpreter, AllocateTwice) {
+ Interpreter interpreter;
+ ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetOutputs({1}), kTfLiteOk);
+
+ TfLiteQuantizationParams quantized;
+ ASSERT_EQ(interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3},
+ quantized),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3},
+ quantized),
+ kTfLiteOk);
+
+ TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
+ reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* tensor1 = &context->tensors[node->outputs->data[0]];
+ TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims);
+ return context->ResizeTensor(context, tensor1, newSize);
+ };
+ reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]];
+ int num = a0->dims->data[0];
+ for (int i = 0; i < num; i++) {
+ a1->data.f[i] = a0->data.f[i];
+ }
+ return kTfLiteOk;
+ };
+ ASSERT_EQ(
+ interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, &reg),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.ResizeInputTensor(0, {3}), kTfLiteOk);
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
+ char* old_tensor0_ptr = interpreter.tensor(0)->data.raw;
+ char* old_tensor1_ptr = interpreter.tensor(1)->data.raw;
+
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
+ ASSERT_EQ(old_tensor0_ptr, interpreter.tensor(0)->data.raw);
+ ASSERT_EQ(old_tensor1_ptr, interpreter.tensor(1)->data.raw);
+}
+
+struct TestErrorReporter : public ErrorReporter {
+ int Report(const char* format, va_list args) override {
+ char buffer[1024];
+ int size = vsnprintf(buffer, sizeof(buffer), format, args);
+ all_reports += buffer;
+ calls++;
+ return size;
+ }
+ int calls = 0;
+ std::string all_reports;
+};
+
+TEST(BasicInterpreter, TestNullErrorReporter) {
+ TestErrorReporter reporter;
+ Interpreter interpreter;
+}
+
+TEST(BasicInterpreter, TestCustomErrorReporter) {
+ TestErrorReporter reporter;
+ Interpreter interpreter(&reporter);
+ ASSERT_NE(interpreter.Invoke(), kTfLiteOk);
+ ASSERT_EQ(reporter.all_reports, "Invoke called on model that is not ready.");
+ ASSERT_EQ(reporter.calls, 1);
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+#ifdef OS_LINUX
+ FLAGS_logtostderr = true;
+#endif
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD
new file mode 100644
index 0000000000..74fb4fe001
--- /dev/null
+++ b/tensorflow/contrib/lite/java/BUILD
@@ -0,0 +1,111 @@
+# Description:
+# TensorFlow Lite Java API.
+
+package(default_visibility = ["//visibility:private"])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow/java:build_defs.bzl", "JAVACOPTS")
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_jni_binary")
+
+android_library(
+ name = "tensorflowlite",
+ srcs = glob(
+ [
+ "src/main/java/org/tensorflow/lite/*.java",
+ ],
+ ),
+ visibility = ["//visibility:public"],
+ deps = [
+ ":tflite_runtime",
+ "@javax_validation",
+ ],
+)
+
+android_library(
+ name = "tensorflowlite_java",
+ srcs = glob(
+ [
+ "src/main/java/org/tensorflow/lite/*.java",
+ ],
+ ),
+ visibility = ["//visibility:public"],
+ deps = [
+ "@javax_validation",
+ ],
+)
+
+java_library(
+ name = "tensorflowlitelib",
+ srcs = glob(
+ [
+ "src/main/java/org/tensorflow/lite/*.java",
+ ],
+ ),
+ javacopts = JAVACOPTS,
+ visibility = ["//visibility:public"],
+ deps = [
+ ":libtensorflowlite_jni.so",
+ "//tensorflow/contrib/lite/java/src/main/native",
+ "@javax_validation",
+ ],
+)
+
+java_test(
+ name = "TensorFlowLiteTest",
+ size = "small",
+ srcs = ["src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java"],
+ javacopts = JAVACOPTS,
+ test_class = "org.tensorflow.lite.TensorFlowLiteTest",
+ deps = [
+ ":tensorflowlitelib",
+ "@com_google_truth",
+ "@junit",
+ ],
+)
+
+java_test(
+ name = "DataTypeTest",
+ size = "small",
+ srcs = ["src/test/java/org/tensorflow/lite/DataTypeTest.java"],
+ javacopts = JAVACOPTS,
+ test_class = "org.tensorflow.lite.DataTypeTest",
+ deps = [
+ ":tensorflowlitelib",
+ "@com_google_truth",
+ "@junit",
+ ],
+)
+
+filegroup(
+ name = "libtensorflowlite_jni",
+ srcs = select({
+ "//conditions:default": [":libtensorflowlite_jni.so"],
+ }),
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "tflite_runtime",
+ srcs = ["libtensorflowlite_jni.so"],
+ visibility = ["//visibility:public"],
+)
+
+tflite_jni_binary(
+ name = "libtensorflowlite_jni.so",
+ deps = [
+ "//tensorflow/contrib/lite/java/src/main/native",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/java/demo/.gitignore b/tensorflow/contrib/lite/java/demo/.gitignore
new file mode 100644
index 0000000000..39fb081a42
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/.gitignore
@@ -0,0 +1,9 @@
+*.iml
+.gradle
+/local.properties
+/.idea/workspace.xml
+/.idea/libraries
+.DS_Store
+/build
+/captures
+.externalNativeBuild
diff --git a/tensorflow/contrib/lite/java/demo/app/build.gradle b/tensorflow/contrib/lite/java/demo/app/build.gradle
new file mode 100644
index 0000000000..e1470fe717
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/build.gradle
@@ -0,0 +1,58 @@
+apply plugin: 'com.android.application'
+
+android {
+ compileSdkVersion 26
+ buildToolsVersion "26.0.1"
+ defaultConfig {
+ applicationId "android.example.com.tflitecamerademo"
+ minSdkVersion 15
+ targetSdkVersion 26
+ versionCode 1
+ versionName "1.0"
+ testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
+
+ // Remove this block.
+ jackOptions {
+ enabled true
+ }
+ }
+ lintOptions {
+ abortOnError false
+ }
+ buildTypes {
+ release {
+ minifyEnabled false
+ proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
+ }
+ }
+ aaptOptions {
+ noCompress "tflite"
+ }
+
+ compileOptions {
+ sourceCompatibility JavaVersion.VERSION_1_8
+ targetCompatibility JavaVersion.VERSION_1_8
+ }
+}
+
+repositories {
+ flatDir {
+ dirs 'libs'
+ }
+}
+
+dependencies {
+ compile fileTree(dir: 'libs', include: ['*.jar'])
+ androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', {
+ exclude group: 'com.android.support', module: 'support-annotations'
+ })
+ compile 'com.android.support:appcompat-v7:25.2.0'
+ compile 'com.android.support.constraint:constraint-layout:1.0.2'
+ compile 'com.android.support:design:25.2.0'
+ compile 'com.android.support:support-annotations:25.3.1'
+ compile 'com.android.support:support-v13:25.2.0'
+
+ compile 'org.tensorflow:tensorflow-lite:+'
+
+ testCompile 'junit:junit:4.12'
+}
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml b/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml
new file mode 100644
index 0000000000..ba63dce5d9
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml
@@ -0,0 +1,42 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ Copyright 2014 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="com.example.android.tflitecamerademo">
+
+ <uses-permission android:name="android.permission.CAMERA" />
+
+ <uses-feature android:name="android.hardware.camera" />
+ <uses-feature android:name="android.hardware.camera.autofocus" />
+
+ <uses-sdk android:minSdkVersion="21" />
+
+ <application android:allowBackup="true"
+ android:label="@string/app_name"
+ android:icon="@drawable/ic_launcher"
+ android:theme="@style/MaterialTheme">
+
+ <activity android:name="com.example.android.tflitecamerademo.CameraActivity"
+ android:label="@string/app_name">
+ <intent-filter>
+ <action android:name="android.intent.action.MAIN" />
+ <category android:name="android.intent.category.LAUNCHER" />
+ </intent-filter>
+ </activity>
+ </application>
+
+</manifest>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD
new file mode 100644
index 0000000000..4fc6d99d8c
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD
@@ -0,0 +1,47 @@
+package(default_visibility = ["//visibility:private"])
+
+licenses(["notice"]) # Apache 2.0
+
+android_binary(
+ name = "TfLiteCameraDemo",
+ srcs = glob(["java/**/*.java"]),
+ assets = [
+ ":assets",
+ ],
+ assets_dir = "",
+ custom_package = "com.example.android.tflitecamerademo",
+ manifest = "AndroidManifest.xml",
+ nocompress_extensions = [
+ ".tflite",
+ ],
+ resource_files = glob(["res/**"]),
+ # In some platforms we don't have an Android SDK/NDK and this target
+ # can't be built. We need to prevent the build system from trying to
+ # use the target in that case.
+ tags = ["manual"],
+ deps = [
+ "//tensorflow/contrib/lite/java:tensorflowlite",
+ "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper",
+ "@androidsdk//com.android.support:support-v13-25.2.0",
+ "@androidsdk//com.android.support:support-v4-25.2.0",
+ ],
+)
+
+filegroup(
+ name = "assets",
+ srcs = [
+ "@tflite_mobilenet//:model_files",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD
new file mode 100644
index 0000000000..1a759f5652
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD
@@ -0,0 +1,26 @@
+package(default_visibility = ["//visibility:private"])
+
+licenses(["notice"]) # Apache 2.0
+
+filegroup(
+ name = "assets_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "BUILD",
+ ],
+ ),
+ visibility = ["//visibility:public"],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt
new file mode 100644
index 0000000000..fe811239d8
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt
@@ -0,0 +1,1001 @@
+background
+tench
+goldfish
+great white shark
+tiger shark
+hammerhead
+electric ray
+stingray
+cock
+hen
+ostrich
+brambling
+goldfinch
+house finch
+junco
+indigo bunting
+robin
+bulbul
+jay
+magpie
+chickadee
+water ouzel
+kite
+bald eagle
+vulture
+great grey owl
+European fire salamander
+common newt
+eft
+spotted salamander
+axolotl
+bullfrog
+tree frog
+tailed frog
+loggerhead
+leatherback turtle
+mud turtle
+terrapin
+box turtle
+banded gecko
+common iguana
+American chameleon
+whiptail
+agama
+frilled lizard
+alligator lizard
+Gila monster
+green lizard
+African chameleon
+Komodo dragon
+African crocodile
+American alligator
+triceratops
+thunder snake
+ringneck snake
+hognose snake
+green snake
+king snake
+garter snake
+water snake
+vine snake
+night snake
+boa constrictor
+rock python
+Indian cobra
+green mamba
+sea snake
+horned viper
+diamondback
+sidewinder
+trilobite
+harvestman
+scorpion
+black and gold garden spider
+barn spider
+garden spider
+black widow
+tarantula
+wolf spider
+tick
+centipede
+black grouse
+ptarmigan
+ruffed grouse
+prairie chicken
+peacock
+quail
+partridge
+African grey
+macaw
+sulphur-crested cockatoo
+lorikeet
+coucal
+bee eater
+hornbill
+hummingbird
+jacamar
+toucan
+drake
+red-breasted merganser
+goose
+black swan
+tusker
+echidna
+platypus
+wallaby
+koala
+wombat
+jellyfish
+sea anemone
+brain coral
+flatworm
+nematode
+conch
+snail
+slug
+sea slug
+chiton
+chambered nautilus
+Dungeness crab
+rock crab
+fiddler crab
+king crab
+American lobster
+spiny lobster
+crayfish
+hermit crab
+isopod
+white stork
+black stork
+spoonbill
+flamingo
+little blue heron
+American egret
+bittern
+crane
+limpkin
+European gallinule
+American coot
+bustard
+ruddy turnstone
+red-backed sandpiper
+redshank
+dowitcher
+oystercatcher
+pelican
+king penguin
+albatross
+grey whale
+killer whale
+dugong
+sea lion
+Chihuahua
+Japanese spaniel
+Maltese dog
+Pekinese
+Shih-Tzu
+Blenheim spaniel
+papillon
+toy terrier
+Rhodesian ridgeback
+Afghan hound
+basset
+beagle
+bloodhound
+bluetick
+black-and-tan coonhound
+Walker hound
+English foxhound
+redbone
+borzoi
+Irish wolfhound
+Italian greyhound
+whippet
+Ibizan hound
+Norwegian elkhound
+otterhound
+Saluki
+Scottish deerhound
+Weimaraner
+Staffordshire bullterrier
+American Staffordshire terrier
+Bedlington terrier
+Border terrier
+Kerry blue terrier
+Irish terrier
+Norfolk terrier
+Norwich terrier
+Yorkshire terrier
+wire-haired fox terrier
+Lakeland terrier
+Sealyham terrier
+Airedale
+cairn
+Australian terrier
+Dandie Dinmont
+Boston bull
+miniature schnauzer
+giant schnauzer
+standard schnauzer
+Scotch terrier
+Tibetan terrier
+silky terrier
+soft-coated wheaten terrier
+West Highland white terrier
+Lhasa
+flat-coated retriever
+curly-coated retriever
+golden retriever
+Labrador retriever
+Chesapeake Bay retriever
+German short-haired pointer
+vizsla
+English setter
+Irish setter
+Gordon setter
+Brittany spaniel
+clumber
+English springer
+Welsh springer spaniel
+cocker spaniel
+Sussex spaniel
+Irish water spaniel
+kuvasz
+schipperke
+groenendael
+malinois
+briard
+kelpie
+komondor
+Old English sheepdog
+Shetland sheepdog
+collie
+Border collie
+Bouvier des Flandres
+Rottweiler
+German shepherd
+Doberman
+miniature pinscher
+Greater Swiss Mountain dog
+Bernese mountain dog
+Appenzeller
+EntleBucher
+boxer
+bull mastiff
+Tibetan mastiff
+French bulldog
+Great Dane
+Saint Bernard
+Eskimo dog
+malamute
+Siberian husky
+dalmatian
+affenpinscher
+basenji
+pug
+Leonberg
+Newfoundland
+Great Pyrenees
+Samoyed
+Pomeranian
+chow
+keeshond
+Brabancon griffon
+Pembroke
+Cardigan
+toy poodle
+miniature poodle
+standard poodle
+Mexican hairless
+timber wolf
+white wolf
+red wolf
+coyote
+dingo
+dhole
+African hunting dog
+hyena
+red fox
+kit fox
+Arctic fox
+grey fox
+tabby
+tiger cat
+Persian cat
+Siamese cat
+Egyptian cat
+cougar
+lynx
+leopard
+snow leopard
+jaguar
+lion
+tiger
+cheetah
+brown bear
+American black bear
+ice bear
+sloth bear
+mongoose
+meerkat
+tiger beetle
+ladybug
+ground beetle
+long-horned beetle
+leaf beetle
+dung beetle
+rhinoceros beetle
+weevil
+fly
+bee
+ant
+grasshopper
+cricket
+walking stick
+cockroach
+mantis
+cicada
+leafhopper
+lacewing
+dragonfly
+damselfly
+admiral
+ringlet
+monarch
+cabbage butterfly
+sulphur butterfly
+lycaenid
+starfish
+sea urchin
+sea cucumber
+wood rabbit
+hare
+Angora
+hamster
+porcupine
+fox squirrel
+marmot
+beaver
+guinea pig
+sorrel
+zebra
+hog
+wild boar
+warthog
+hippopotamus
+ox
+water buffalo
+bison
+ram
+bighorn
+ibex
+hartebeest
+impala
+gazelle
+Arabian camel
+llama
+weasel
+mink
+polecat
+black-footed ferret
+otter
+skunk
+badger
+armadillo
+three-toed sloth
+orangutan
+gorilla
+chimpanzee
+gibbon
+siamang
+guenon
+patas
+baboon
+macaque
+langur
+colobus
+proboscis monkey
+marmoset
+capuchin
+howler monkey
+titi
+spider monkey
+squirrel monkey
+Madagascar cat
+indri
+Indian elephant
+African elephant
+lesser panda
+giant panda
+barracouta
+eel
+coho
+rock beauty
+anemone fish
+sturgeon
+gar
+lionfish
+puffer
+abacus
+abaya
+academic gown
+accordion
+acoustic guitar
+aircraft carrier
+airliner
+airship
+altar
+ambulance
+amphibian
+analog clock
+apiary
+apron
+ashcan
+assault rifle
+backpack
+bakery
+balance beam
+balloon
+ballpoint
+Band Aid
+banjo
+bannister
+barbell
+barber chair
+barbershop
+barn
+barometer
+barrel
+barrow
+baseball
+basketball
+bassinet
+bassoon
+bathing cap
+bath towel
+bathtub
+beach wagon
+beacon
+beaker
+bearskin
+beer bottle
+beer glass
+bell cote
+bib
+bicycle-built-for-two
+bikini
+binder
+binoculars
+birdhouse
+boathouse
+bobsled
+bolo tie
+bonnet
+bookcase
+bookshop
+bottlecap
+bow
+bow tie
+brass
+brassiere
+breakwater
+breastplate
+broom
+bucket
+buckle
+bulletproof vest
+bullet train
+butcher shop
+cab
+caldron
+candle
+cannon
+canoe
+can opener
+cardigan
+car mirror
+carousel
+carpenter's kit
+carton
+car wheel
+cash machine
+cassette
+cassette player
+castle
+catamaran
+CD player
+cello
+cellular telephone
+chain
+chainlink fence
+chain mail
+chain saw
+chest
+chiffonier
+chime
+china cabinet
+Christmas stocking
+church
+cinema
+cleaver
+cliff dwelling
+cloak
+clog
+cocktail shaker
+coffee mug
+coffeepot
+coil
+combination lock
+computer keyboard
+confectionery
+container ship
+convertible
+corkscrew
+cornet
+cowboy boot
+cowboy hat
+cradle
+crane
+crash helmet
+crate
+crib
+Crock Pot
+croquet ball
+crutch
+cuirass
+dam
+desk
+desktop computer
+dial telephone
+diaper
+digital clock
+digital watch
+dining table
+dishrag
+dishwasher
+disk brake
+dock
+dogsled
+dome
+doormat
+drilling platform
+drum
+drumstick
+dumbbell
+Dutch oven
+electric fan
+electric guitar
+electric locomotive
+entertainment center
+envelope
+espresso maker
+face powder
+feather boa
+file
+fireboat
+fire engine
+fire screen
+flagpole
+flute
+folding chair
+football helmet
+forklift
+fountain
+fountain pen
+four-poster
+freight car
+French horn
+frying pan
+fur coat
+garbage truck
+gasmask
+gas pump
+goblet
+go-kart
+golf ball
+golfcart
+gondola
+gong
+gown
+grand piano
+greenhouse
+grille
+grocery store
+guillotine
+hair slide
+hair spray
+half track
+hammer
+hamper
+hand blower
+hand-held computer
+handkerchief
+hard disc
+harmonica
+harp
+harvester
+hatchet
+holster
+home theater
+honeycomb
+hook
+hoopskirt
+horizontal bar
+horse cart
+hourglass
+iPod
+iron
+jack-o'-lantern
+jean
+jeep
+jersey
+jigsaw puzzle
+jinrikisha
+joystick
+kimono
+knee pad
+knot
+lab coat
+ladle
+lampshade
+laptop
+lawn mower
+lens cap
+letter opener
+library
+lifeboat
+lighter
+limousine
+liner
+lipstick
+Loafer
+lotion
+loudspeaker
+loupe
+lumbermill
+magnetic compass
+mailbag
+mailbox
+maillot
+maillot
+manhole cover
+maraca
+marimba
+mask
+matchstick
+maypole
+maze
+measuring cup
+medicine chest
+megalith
+microphone
+microwave
+military uniform
+milk can
+minibus
+miniskirt
+minivan
+missile
+mitten
+mixing bowl
+mobile home
+Model T
+modem
+monastery
+monitor
+moped
+mortar
+mortarboard
+mosque
+mosquito net
+motor scooter
+mountain bike
+mountain tent
+mouse
+mousetrap
+moving van
+muzzle
+nail
+neck brace
+necklace
+nipple
+notebook
+obelisk
+oboe
+ocarina
+odometer
+oil filter
+organ
+oscilloscope
+overskirt
+oxcart
+oxygen mask
+packet
+paddle
+paddlewheel
+padlock
+paintbrush
+pajama
+palace
+panpipe
+paper towel
+parachute
+parallel bars
+park bench
+parking meter
+passenger car
+patio
+pay-phone
+pedestal
+pencil box
+pencil sharpener
+perfume
+Petri dish
+photocopier
+pick
+pickelhaube
+picket fence
+pickup
+pier
+piggy bank
+pill bottle
+pillow
+ping-pong ball
+pinwheel
+pirate
+pitcher
+plane
+planetarium
+plastic bag
+plate rack
+plow
+plunger
+Polaroid camera
+pole
+police van
+poncho
+pool table
+pop bottle
+pot
+potter's wheel
+power drill
+prayer rug
+printer
+prison
+projectile
+projector
+puck
+punching bag
+purse
+quill
+quilt
+racer
+racket
+radiator
+radio
+radio telescope
+rain barrel
+recreational vehicle
+reel
+reflex camera
+refrigerator
+remote control
+restaurant
+revolver
+rifle
+rocking chair
+rotisserie
+rubber eraser
+rugby ball
+rule
+running shoe
+safe
+safety pin
+saltshaker
+sandal
+sarong
+sax
+scabbard
+scale
+school bus
+schooner
+scoreboard
+screen
+screw
+screwdriver
+seat belt
+sewing machine
+shield
+shoe shop
+shoji
+shopping basket
+shopping cart
+shovel
+shower cap
+shower curtain
+ski
+ski mask
+sleeping bag
+slide rule
+sliding door
+slot
+snorkel
+snowmobile
+snowplow
+soap dispenser
+soccer ball
+sock
+solar dish
+sombrero
+soup bowl
+space bar
+space heater
+space shuttle
+spatula
+speedboat
+spider web
+spindle
+sports car
+spotlight
+stage
+steam locomotive
+steel arch bridge
+steel drum
+stethoscope
+stole
+stone wall
+stopwatch
+stove
+strainer
+streetcar
+stretcher
+studio couch
+stupa
+submarine
+suit
+sundial
+sunglass
+sunglasses
+sunscreen
+suspension bridge
+swab
+sweatshirt
+swimming trunks
+swing
+switch
+syringe
+table lamp
+tank
+tape player
+teapot
+teddy
+television
+tennis ball
+thatch
+theater curtain
+thimble
+thresher
+throne
+tile roof
+toaster
+tobacco shop
+toilet seat
+torch
+totem pole
+tow truck
+toyshop
+tractor
+trailer truck
+tray
+trench coat
+tricycle
+trimaran
+tripod
+triumphal arch
+trolleybus
+trombone
+tub
+turnstile
+typewriter keyboard
+umbrella
+unicycle
+upright
+vacuum
+vase
+vault
+velvet
+vending machine
+vestment
+viaduct
+violin
+volleyball
+waffle iron
+wall clock
+wallet
+wardrobe
+warplane
+washbasin
+washer
+water bottle
+water jug
+water tower
+whiskey jug
+whistle
+wig
+window screen
+window shade
+Windsor tie
+wine bottle
+wing
+wok
+wooden spoon
+wool
+worm fence
+wreck
+yawl
+yurt
+web site
+comic book
+crossword puzzle
+street sign
+traffic light
+book jacket
+menu
+plate
+guacamole
+consomme
+hot pot
+trifle
+ice cream
+ice lolly
+French loaf
+bagel
+pretzel
+cheeseburger
+hotdog
+mashed potato
+head cabbage
+broccoli
+cauliflower
+zucchini
+spaghetti squash
+acorn squash
+butternut squash
+cucumber
+artichoke
+bell pepper
+cardoon
+mushroom
+Granny Smith
+strawberry
+orange
+lemon
+fig
+pineapple
+banana
+jackfruit
+custard apple
+pomegranate
+hay
+carbonara
+chocolate sauce
+dough
+meat loaf
+pizza
+potpie
+burrito
+red wine
+espresso
+cup
+eggnog
+alp
+bubble
+cliff
+coral reef
+geyser
+lakeside
+promontory
+sandbar
+seashore
+valley
+volcano
+ballplayer
+groom
+scuba diver
+rapeseed
+daisy
+yellow lady's slipper
+corn
+acorn
+hip
+buckeye
+coral fungus
+agaric
+gyromitra
+stinkhorn
+earthstar
+hen-of-the-woods
+bolete
+ear
+toilet tissue
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java
new file mode 100644
index 0000000000..f204590659
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java
@@ -0,0 +1,72 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package com.example.android.tflitecamerademo;
+
+import android.content.Context;
+import android.util.AttributeSet;
+import android.view.TextureView;
+
+/** A {@link TextureView} that can be adjusted to a specified aspect ratio. */
+public class AutoFitTextureView extends TextureView {
+
+ private int mRatioWidth = 0;
+ private int mRatioHeight = 0;
+
+ public AutoFitTextureView(Context context) {
+ this(context, null);
+ }
+
+ public AutoFitTextureView(Context context, AttributeSet attrs) {
+ this(context, attrs, 0);
+ }
+
+ public AutoFitTextureView(Context context, AttributeSet attrs, int defStyle) {
+ super(context, attrs, defStyle);
+ }
+
+ /**
+ * Sets the aspect ratio for this view. The size of the view will be measured based on the ratio
+ * calculated from the parameters. Note that the actual sizes of parameters don't matter, that is,
+ * calling setAspectRatio(2, 3) and setAspectRatio(4, 6) make the same result.
+ *
+ * @param width Relative horizontal size
+ * @param height Relative vertical size
+ */
+ public void setAspectRatio(int width, int height) {
+ if (width < 0 || height < 0) {
+ throw new IllegalArgumentException("Size cannot be negative.");
+ }
+ mRatioWidth = width;
+ mRatioHeight = height;
+ requestLayout();
+ }
+
+ @Override
+ protected void onMeasure(int widthMeasureSpec, int heightMeasureSpec) {
+ super.onMeasure(widthMeasureSpec, heightMeasureSpec);
+ int width = MeasureSpec.getSize(widthMeasureSpec);
+ int height = MeasureSpec.getSize(heightMeasureSpec);
+ if (0 == mRatioWidth || 0 == mRatioHeight) {
+ setMeasuredDimension(width, height);
+ } else {
+ if (width < height * mRatioWidth / mRatioHeight) {
+ setMeasuredDimension(width, width * mRatioHeight / mRatioWidth);
+ } else {
+ setMeasuredDimension(height * mRatioWidth / mRatioHeight, height);
+ }
+ }
+ }
+}
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
new file mode 100644
index 0000000000..74737a8b88
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
@@ -0,0 +1,708 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package com.example.android.tflitecamerademo;
+
+import android.app.Activity;
+import android.app.AlertDialog;
+import android.app.Dialog;
+import android.app.DialogFragment;
+import android.app.Fragment;
+import android.content.Context;
+import android.content.DialogInterface;
+import android.content.pm.PackageInfo;
+import android.content.pm.PackageManager;
+import android.content.res.Configuration;
+import android.graphics.Bitmap;
+import android.graphics.ImageFormat;
+import android.graphics.Matrix;
+import android.graphics.Point;
+import android.graphics.RectF;
+import android.graphics.SurfaceTexture;
+import android.hardware.camera2.CameraAccessException;
+import android.hardware.camera2.CameraCaptureSession;
+import android.hardware.camera2.CameraCharacteristics;
+import android.hardware.camera2.CameraDevice;
+import android.hardware.camera2.CameraManager;
+import android.hardware.camera2.CaptureRequest;
+import android.hardware.camera2.CaptureResult;
+import android.hardware.camera2.TotalCaptureResult;
+import android.hardware.camera2.params.StreamConfigurationMap;
+import android.media.ImageReader;
+import android.os.Bundle;
+import android.os.Handler;
+import android.os.HandlerThread;
+import android.support.annotation.NonNull;
+import android.support.v13.app.FragmentCompat;
+import android.support.v4.content.ContextCompat;
+import android.util.Log;
+import android.util.Size;
+import android.view.LayoutInflater;
+import android.view.Surface;
+import android.view.TextureView;
+import android.view.View;
+import android.view.ViewGroup;
+import android.widget.TextView;
+import android.widget.Toast;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+
+/** Basic fragments for the Camera. */
+public class Camera2BasicFragment extends Fragment
+ implements FragmentCompat.OnRequestPermissionsResultCallback {
+
+ /** Tag for the {@link Log}. */
+ private static final String TAG = "TfLiteCameraDemo";
+
+ private static final String FRAGMENT_DIALOG = "dialog";
+
+ private static final String HANDLE_THREAD_NAME = "CameraBackground";
+
+ private static final int PERMISSIONS_REQUEST_CODE = 1;
+
+ private final Object lock = new Object();
+ private boolean runClassifier = false;
+ private boolean checkedPermissions = false;
+ private TextView textView;
+ private ImageClassifier classifier;
+
+ /** Max preview width that is guaranteed by Camera2 API */
+ private static final int MAX_PREVIEW_WIDTH = 1920;
+
+ /** Max preview height that is guaranteed by Camera2 API */
+ private static final int MAX_PREVIEW_HEIGHT = 1080;
+
+ /**
+ * {@link TextureView.SurfaceTextureListener} handles several lifecycle events on a {@link
+ * TextureView}.
+ */
+ private final TextureView.SurfaceTextureListener surfaceTextureListener =
+ new TextureView.SurfaceTextureListener() {
+
+ @Override
+ public void onSurfaceTextureAvailable(SurfaceTexture texture, int width, int height) {
+ openCamera(width, height);
+ }
+
+ @Override
+ public void onSurfaceTextureSizeChanged(SurfaceTexture texture, int width, int height) {
+ configureTransform(width, height);
+ }
+
+ @Override
+ public boolean onSurfaceTextureDestroyed(SurfaceTexture texture) {
+ return true;
+ }
+
+ @Override
+ public void onSurfaceTextureUpdated(SurfaceTexture texture) {}
+ };
+
+ /** ID of the current {@link CameraDevice}. */
+ private String cameraId;
+
+ /** An {@link AutoFitTextureView} for camera preview. */
+ private AutoFitTextureView textureView;
+
+ /** A {@link CameraCaptureSession } for camera preview. */
+ private CameraCaptureSession captureSession;
+
+ /** A reference to the opened {@link CameraDevice}. */
+ private CameraDevice cameraDevice;
+
+ /** The {@link android.util.Size} of camera preview. */
+ private Size previewSize;
+
+ /** {@link CameraDevice.StateCallback} is called when {@link CameraDevice} changes its state. */
+ private final CameraDevice.StateCallback stateCallback =
+ new CameraDevice.StateCallback() {
+
+ @Override
+ public void onOpened(@NonNull CameraDevice currentCameraDevice) {
+ // This method is called when the camera is opened. We start camera preview here.
+ cameraOpenCloseLock.release();
+ cameraDevice = currentCameraDevice;
+ createCameraPreviewSession();
+ }
+
+ @Override
+ public void onDisconnected(@NonNull CameraDevice currentCameraDevice) {
+ cameraOpenCloseLock.release();
+ currentCameraDevice.close();
+ cameraDevice = null;
+ }
+
+ @Override
+ public void onError(@NonNull CameraDevice currentCameraDevice, int error) {
+ cameraOpenCloseLock.release();
+ currentCameraDevice.close();
+ cameraDevice = null;
+ Activity activity = getActivity();
+ if (null != activity) {
+ activity.finish();
+ }
+ }
+ };
+
+ /** An additional thread for running tasks that shouldn't block the UI. */
+ private HandlerThread backgroundThread;
+
+ /** A {@link Handler} for running tasks in the background. */
+ private Handler backgroundHandler;
+
+ /** An {@link ImageReader} that handles image capture. */
+ private ImageReader imageReader;
+
+ /** {@link CaptureRequest.Builder} for the camera preview */
+ private CaptureRequest.Builder previewRequestBuilder;
+
+ /** {@link CaptureRequest} generated by {@link #previewRequestBuilder} */
+ private CaptureRequest previewRequest;
+
+ /** A {@link Semaphore} to prevent the app from exiting before closing the camera. */
+ private Semaphore cameraOpenCloseLock = new Semaphore(1);
+
+ /** A {@link CameraCaptureSession.CaptureCallback} that handles events related to capture. */
+ private CameraCaptureSession.CaptureCallback captureCallback =
+ new CameraCaptureSession.CaptureCallback() {
+
+ @Override
+ public void onCaptureProgressed(
+ @NonNull CameraCaptureSession session,
+ @NonNull CaptureRequest request,
+ @NonNull CaptureResult partialResult) {}
+
+ @Override
+ public void onCaptureCompleted(
+ @NonNull CameraCaptureSession session,
+ @NonNull CaptureRequest request,
+ @NonNull TotalCaptureResult result) {}
+ };
+
+ /**
+ * Shows a {@link Toast} on the UI thread for the classification results.
+ *
+ * @param text The message to show
+ */
+ private void showToast(final String text) {
+ final Activity activity = getActivity();
+ if (activity != null) {
+ activity.runOnUiThread(
+ new Runnable() {
+ @Override
+ public void run() {
+ textView.setText(text);
+ }
+ });
+ }
+ }
+
+ /**
+ * Resizes image.
+ *
+ * Attempting to use too large a preview size could exceed the camera bus' bandwidth limitation,
+ * resulting in gorgeous previews but the storage of garbage capture data.
+ *
+ * Given {@code choices} of {@code Size}s supported by a camera, choose the smallest one that is
+ * at least as large as the respective texture view size, and that is at most as large as the
+ * respective max size, and whose aspect ratio matches with the specified value. If such size
+ * doesn't exist, choose the largest one that is at most as large as the respective max size, and
+ * whose aspect ratio matches with the specified value.
+ *
+ * @param choices The list of sizes that the camera supports for the intended output class
+ * @param textureViewWidth The width of the texture view relative to sensor coordinate
+ * @param textureViewHeight The height of the texture view relative to sensor coordinate
+ * @param maxWidth The maximum width that can be chosen
+ * @param maxHeight The maximum height that can be chosen
+ * @param aspectRatio The aspect ratio
+ * @return The optimal {@code Size}, or an arbitrary one if none were big enough
+ */
+ private static Size chooseOptimalSize(
+ Size[] choices,
+ int textureViewWidth,
+ int textureViewHeight,
+ int maxWidth,
+ int maxHeight,
+ Size aspectRatio) {
+
+ // Collect the supported resolutions that are at least as big as the preview Surface
+ List<Size> bigEnough = new ArrayList<>();
+ // Collect the supported resolutions that are smaller than the preview Surface
+ List<Size> notBigEnough = new ArrayList<>();
+ int w = aspectRatio.getWidth();
+ int h = aspectRatio.getHeight();
+ for (Size option : choices) {
+ if (option.getWidth() <= maxWidth
+ && option.getHeight() <= maxHeight
+ && option.getHeight() == option.getWidth() * h / w) {
+ if (option.getWidth() >= textureViewWidth && option.getHeight() >= textureViewHeight) {
+ bigEnough.add(option);
+ } else {
+ notBigEnough.add(option);
+ }
+ }
+ }
+
+ // Pick the smallest of those big enough. If there is no one big enough, pick the
+ // largest of those not big enough.
+ if (bigEnough.size() > 0) {
+ return Collections.min(bigEnough, new CompareSizesByArea());
+ } else if (notBigEnough.size() > 0) {
+ return Collections.max(notBigEnough, new CompareSizesByArea());
+ } else {
+ Log.e(TAG, "Couldn't find any suitable preview size");
+ return choices[0];
+ }
+ }
+
+ public static Camera2BasicFragment newInstance() {
+ return new Camera2BasicFragment();
+ }
+
+ /** Layout the preview and buttons. */
+ @Override
+ public View onCreateView(
+ LayoutInflater inflater, ViewGroup container, Bundle savedInstanceState) {
+ return inflater.inflate(R.layout.fragment_camera2_basic, container, false);
+ }
+
+ /** Connect the buttons to their event handler. */
+ @Override
+ public void onViewCreated(final View view, Bundle savedInstanceState) {
+ textureView = (AutoFitTextureView) view.findViewById(R.id.texture);
+ textView = (TextView) view.findViewById(R.id.text);
+ }
+
+ /** Load the model and labels. */
+ @Override
+ public void onActivityCreated(Bundle savedInstanceState) {
+ super.onActivityCreated(savedInstanceState);
+ try {
+ classifier = new ImageClassifier(getActivity());
+ } catch (IOException e) {
+ Log.e(TAG, "Failed to initialize an image classifier.");
+ }
+ startBackgroundThread();
+ }
+
+ @Override
+ public void onResume() {
+ super.onResume();
+ startBackgroundThread();
+
+ // When the screen is turned off and turned back on, the SurfaceTexture is already
+ // available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open
+ // a camera and start preview from here (otherwise, we wait until the surface is ready in
+ // the SurfaceTextureListener).
+ if (textureView.isAvailable()) {
+ openCamera(textureView.getWidth(), textureView.getHeight());
+ } else {
+ textureView.setSurfaceTextureListener(surfaceTextureListener);
+ }
+ }
+
+ @Override
+ public void onPause() {
+ closeCamera();
+ stopBackgroundThread();
+ super.onPause();
+ }
+
+ @Override
+ public void onDestroy() {
+ classifier.close();
+ super.onDestroy();
+ }
+
+ /**
+ * Sets up member variables related to camera.
+ *
+ * @param width The width of available size for camera preview
+ * @param height The height of available size for camera preview
+ */
+ private void setUpCameraOutputs(int width, int height) {
+ Activity activity = getActivity();
+ CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE);
+ try {
+ for (String cameraId : manager.getCameraIdList()) {
+ CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId);
+
+ // We don't use a front facing camera in this sample.
+ Integer facing = characteristics.get(CameraCharacteristics.LENS_FACING);
+ if (facing != null && facing == CameraCharacteristics.LENS_FACING_FRONT) {
+ continue;
+ }
+
+ StreamConfigurationMap map =
+ characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP);
+ if (map == null) {
+ continue;
+ }
+
+ // // For still image captures, we use the largest available size.
+ Size largest =
+ Collections.max(
+ Arrays.asList(map.getOutputSizes(ImageFormat.JPEG)), new CompareSizesByArea());
+ imageReader =
+ ImageReader.newInstance(
+ largest.getWidth(), largest.getHeight(), ImageFormat.JPEG, /*maxImages*/ 2);
+
+ // Find out if we need to swap dimension to get the preview size relative to sensor
+ // coordinate.
+ int displayRotation = activity.getWindowManager().getDefaultDisplay().getRotation();
+ // noinspection ConstantConditions
+ /* Orientation of the camera sensor */
+ int sensorOrientation = characteristics.get(CameraCharacteristics.SENSOR_ORIENTATION);
+ boolean swappedDimensions = false;
+ switch (displayRotation) {
+ case Surface.ROTATION_0:
+ case Surface.ROTATION_180:
+ if (sensorOrientation == 90 || sensorOrientation == 270) {
+ swappedDimensions = true;
+ }
+ break;
+ case Surface.ROTATION_90:
+ case Surface.ROTATION_270:
+ if (sensorOrientation == 0 || sensorOrientation == 180) {
+ swappedDimensions = true;
+ }
+ break;
+ default:
+ Log.e(TAG, "Display rotation is invalid: " + displayRotation);
+ }
+
+ Point displaySize = new Point();
+ activity.getWindowManager().getDefaultDisplay().getSize(displaySize);
+ int rotatedPreviewWidth = width;
+ int rotatedPreviewHeight = height;
+ int maxPreviewWidth = displaySize.x;
+ int maxPreviewHeight = displaySize.y;
+
+ if (swappedDimensions) {
+ rotatedPreviewWidth = height;
+ rotatedPreviewHeight = width;
+ maxPreviewWidth = displaySize.y;
+ maxPreviewHeight = displaySize.x;
+ }
+
+ if (maxPreviewWidth > MAX_PREVIEW_WIDTH) {
+ maxPreviewWidth = MAX_PREVIEW_WIDTH;
+ }
+
+ if (maxPreviewHeight > MAX_PREVIEW_HEIGHT) {
+ maxPreviewHeight = MAX_PREVIEW_HEIGHT;
+ }
+
+ previewSize =
+ chooseOptimalSize(
+ map.getOutputSizes(SurfaceTexture.class),
+ rotatedPreviewWidth,
+ rotatedPreviewHeight,
+ maxPreviewWidth,
+ maxPreviewHeight,
+ largest);
+
+ // We fit the aspect ratio of TextureView to the size of preview we picked.
+ int orientation = getResources().getConfiguration().orientation;
+ if (orientation == Configuration.ORIENTATION_LANDSCAPE) {
+ textureView.setAspectRatio(previewSize.getWidth(), previewSize.getHeight());
+ } else {
+ textureView.setAspectRatio(previewSize.getHeight(), previewSize.getWidth());
+ }
+
+ this.cameraId = cameraId;
+ return;
+ }
+ } catch (CameraAccessException e) {
+ e.printStackTrace();
+ } catch (NullPointerException e) {
+ // Currently an NPE is thrown when the Camera2API is used but not supported on the
+ // device this code runs.
+ ErrorDialog.newInstance(getString(R.string.camera_error))
+ .show(getChildFragmentManager(), FRAGMENT_DIALOG);
+ }
+ }
+
+ private String[] getRequiredPermissions() {
+ Activity activity = getActivity();
+ try {
+ PackageInfo info =
+ activity
+ .getPackageManager()
+ .getPackageInfo(activity.getPackageName(), PackageManager.GET_PERMISSIONS);
+ String[] ps = info.requestedPermissions;
+ if (ps != null && ps.length > 0) {
+ return ps;
+ } else {
+ return new String[0];
+ }
+ } catch (Exception e) {
+ return new String[0];
+ }
+ }
+
+ /** Opens the camera specified by {@link Camera2BasicFragment#cameraId}. */
+ private void openCamera(int width, int height) {
+ if (!checkedPermissions && !allPermissionsGranted()) {
+ FragmentCompat.requestPermissions(this, getRequiredPermissions(), PERMISSIONS_REQUEST_CODE);
+ return;
+ } else {
+ checkedPermissions = true;
+ }
+ setUpCameraOutputs(width, height);
+ configureTransform(width, height);
+ Activity activity = getActivity();
+ CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE);
+ try {
+ if (!cameraOpenCloseLock.tryAcquire(2500, TimeUnit.MILLISECONDS)) {
+ throw new RuntimeException("Time out waiting to lock camera opening.");
+ }
+ manager.openCamera(cameraId, stateCallback, backgroundHandler);
+ } catch (CameraAccessException e) {
+ e.printStackTrace();
+ } catch (InterruptedException e) {
+ throw new RuntimeException("Interrupted while trying to lock camera opening.", e);
+ }
+ }
+
+ private boolean allPermissionsGranted() {
+ for (String permission : getRequiredPermissions()) {
+ if (ContextCompat.checkSelfPermission(getActivity(), permission)
+ != PackageManager.PERMISSION_GRANTED) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ @Override
+ public void onRequestPermissionsResult(
+ int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) {
+ super.onRequestPermissionsResult(requestCode, permissions, grantResults);
+ }
+
+ /** Closes the current {@link CameraDevice}. */
+ private void closeCamera() {
+ try {
+ cameraOpenCloseLock.acquire();
+ if (null != captureSession) {
+ captureSession.close();
+ captureSession = null;
+ }
+ if (null != cameraDevice) {
+ cameraDevice.close();
+ cameraDevice = null;
+ }
+ if (null != imageReader) {
+ imageReader.close();
+ imageReader = null;
+ }
+ } catch (InterruptedException e) {
+ throw new RuntimeException("Interrupted while trying to lock camera closing.", e);
+ } finally {
+ cameraOpenCloseLock.release();
+ }
+ }
+
+ /** Starts a background thread and its {@link Handler}. */
+ private void startBackgroundThread() {
+ backgroundThread = new HandlerThread(HANDLE_THREAD_NAME);
+ backgroundThread.start();
+ backgroundHandler = new Handler(backgroundThread.getLooper());
+ synchronized (lock) {
+ runClassifier = true;
+ }
+ backgroundHandler.post(periodicClassify);
+ }
+
+ /** Stops the background thread and its {@link Handler}. */
+ private void stopBackgroundThread() {
+ backgroundThread.quitSafely();
+ try {
+ backgroundThread.join();
+ backgroundThread = null;
+ backgroundHandler = null;
+ synchronized (lock) {
+ runClassifier = false;
+ }
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ }
+
+ /** Takes photos and classify them periodically. */
+ private Runnable periodicClassify =
+ new Runnable() {
+ @Override
+ public void run() {
+ synchronized (lock) {
+ if (runClassifier) {
+ classifyFrame();
+ }
+ }
+ backgroundHandler.post(periodicClassify);
+ }
+ };
+
+ /** Creates a new {@link CameraCaptureSession} for camera preview. */
+ private void createCameraPreviewSession() {
+ try {
+ SurfaceTexture texture = textureView.getSurfaceTexture();
+ assert texture != null;
+
+ // We configure the size of default buffer to be the size of camera preview we want.
+ texture.setDefaultBufferSize(previewSize.getWidth(), previewSize.getHeight());
+
+ // This is the output Surface we need to start preview.
+ Surface surface = new Surface(texture);
+
+ // We set up a CaptureRequest.Builder with the output Surface.
+ previewRequestBuilder = cameraDevice.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW);
+ previewRequestBuilder.addTarget(surface);
+
+ // Here, we create a CameraCaptureSession for camera preview.
+ cameraDevice.createCaptureSession(
+ Arrays.asList(surface),
+ new CameraCaptureSession.StateCallback() {
+
+ @Override
+ public void onConfigured(@NonNull CameraCaptureSession cameraCaptureSession) {
+ // The camera is already closed
+ if (null == cameraDevice) {
+ return;
+ }
+
+ // When the session is ready, we start displaying the preview.
+ captureSession = cameraCaptureSession;
+ try {
+ // Auto focus should be continuous for camera preview.
+ previewRequestBuilder.set(
+ CaptureRequest.CONTROL_AF_MODE,
+ CaptureRequest.CONTROL_AF_MODE_CONTINUOUS_PICTURE);
+
+ // Finally, we start displaying the camera preview.
+ previewRequest = previewRequestBuilder.build();
+ captureSession.setRepeatingRequest(
+ previewRequest, captureCallback, backgroundHandler);
+ } catch (CameraAccessException e) {
+ e.printStackTrace();
+ }
+ }
+
+ @Override
+ public void onConfigureFailed(@NonNull CameraCaptureSession cameraCaptureSession) {
+ showToast("Failed");
+ }
+ },
+ null);
+ } catch (CameraAccessException e) {
+ e.printStackTrace();
+ }
+ }
+
+ /**
+ * Configures the necessary {@link android.graphics.Matrix} transformation to `textureView`. This
+ * method should be called after the camera preview size is determined in setUpCameraOutputs and
+ * also the size of `textureView` is fixed.
+ *
+ * @param viewWidth The width of `textureView`
+ * @param viewHeight The height of `textureView`
+ */
+ private void configureTransform(int viewWidth, int viewHeight) {
+ Activity activity = getActivity();
+ if (null == textureView || null == previewSize || null == activity) {
+ return;
+ }
+ int rotation = activity.getWindowManager().getDefaultDisplay().getRotation();
+ Matrix matrix = new Matrix();
+ RectF viewRect = new RectF(0, 0, viewWidth, viewHeight);
+ RectF bufferRect = new RectF(0, 0, previewSize.getHeight(), previewSize.getWidth());
+ float centerX = viewRect.centerX();
+ float centerY = viewRect.centerY();
+ if (Surface.ROTATION_90 == rotation || Surface.ROTATION_270 == rotation) {
+ bufferRect.offset(centerX - bufferRect.centerX(), centerY - bufferRect.centerY());
+ matrix.setRectToRect(viewRect, bufferRect, Matrix.ScaleToFit.FILL);
+ float scale =
+ Math.max(
+ (float) viewHeight / previewSize.getHeight(),
+ (float) viewWidth / previewSize.getWidth());
+ matrix.postScale(scale, scale, centerX, centerY);
+ matrix.postRotate(90 * (rotation - 2), centerX, centerY);
+ } else if (Surface.ROTATION_180 == rotation) {
+ matrix.postRotate(180, centerX, centerY);
+ }
+ textureView.setTransform(matrix);
+ }
+
+ /** Classifies a frame from the preview stream. */
+ private void classifyFrame() {
+ if (classifier == null || getActivity() == null || cameraDevice == null) {
+ showToast("Uninitialized Classifier or invalid context.");
+ return;
+ }
+ Bitmap bitmap =
+ textureView.getBitmap(ImageClassifier.DIM_IMG_SIZE_X, ImageClassifier.DIM_IMG_SIZE_Y);
+ String textToShow = classifier.classifyFrame(bitmap);
+ bitmap.recycle();
+ showToast(textToShow);
+ }
+
+ /** Compares two {@code Size}s based on their areas. */
+ private static class CompareSizesByArea implements Comparator<Size> {
+
+ @Override
+ public int compare(Size lhs, Size rhs) {
+ // We cast here to ensure the multiplications won't overflow
+ return Long.signum(
+ (long) lhs.getWidth() * lhs.getHeight() - (long) rhs.getWidth() * rhs.getHeight());
+ }
+ }
+
+ /** Shows an error message dialog. */
+ public static class ErrorDialog extends DialogFragment {
+
+ private static final String ARG_MESSAGE = "message";
+
+ public static ErrorDialog newInstance(String message) {
+ ErrorDialog dialog = new ErrorDialog();
+ Bundle args = new Bundle();
+ args.putString(ARG_MESSAGE, message);
+ dialog.setArguments(args);
+ return dialog;
+ }
+
+ @Override
+ public Dialog onCreateDialog(Bundle savedInstanceState) {
+ final Activity activity = getActivity();
+ return new AlertDialog.Builder(activity)
+ .setMessage(getArguments().getString(ARG_MESSAGE))
+ .setPositiveButton(
+ android.R.string.ok,
+ new DialogInterface.OnClickListener() {
+ @Override
+ public void onClick(DialogInterface dialogInterface, int i) {
+ activity.finish();
+ }
+ })
+ .create();
+ }
+ }
+}
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java
new file mode 100644
index 0000000000..e7161ddb26
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java
@@ -0,0 +1,35 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package com.example.android.tflitecamerademo;
+
+import android.app.Activity;
+import android.os.Bundle;
+
+/** Main {@code Activity} class for the Camera app. */
+public class CameraActivity extends Activity {
+
+ @Override
+ protected void onCreate(Bundle savedInstanceState) {
+ super.onCreate(savedInstanceState);
+ setContentView(R.layout.activity_camera);
+ if (null == savedInstanceState) {
+ getFragmentManager()
+ .beginTransaction()
+ .replace(R.id.container, Camera2BasicFragment.newInstance())
+ .commit();
+ }
+ }
+}
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
new file mode 100644
index 0000000000..e7bad46370
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
@@ -0,0 +1,184 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package com.example.android.tflitecamerademo;
+
+import android.app.Activity;
+import android.content.res.AssetFileDescriptor;
+import android.graphics.Bitmap;
+import android.os.SystemClock;
+import android.util.Log;
+import java.io.BufferedReader;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.MappedByteBuffer;
+import java.nio.channels.FileChannel;
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import org.tensorflow.lite.Interpreter;
+
+/** Classifies images with Tensorflow Lite. */
+public class ImageClassifier {
+
+ /** Tag for the {@link Log}. */
+ private static final String TAG = "TfLiteCameraDemo";
+
+ /** Name of the model file stored in Assets. */
+ private static final String MODEL_PATH = "mobilenet_quant_v1_224.tflite";
+
+ /** Name of the label file stored in Assets. */
+ private static final String LABEL_PATH = "labels.txt";
+
+ /** Number of results to show in the UI. */
+ private static final int RESULTS_TO_SHOW = 3;
+
+ /** Dimensions of inputs. */
+ private static final int DIM_BATCH_SIZE = 1;
+
+ private static final int DIM_PIXEL_SIZE = 3;
+
+ static final int DIM_IMG_SIZE_X = 224;
+ static final int DIM_IMG_SIZE_Y = 224;
+
+ /* Preallocated buffers for storing image data in. */
+ private int[] intValues = new int[DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y];
+
+ /** An instance of the driver class to run model inference with Tensorflow Lite. */
+ private Interpreter tflite;
+
+ /** Labels corresponding to the output of the vision model. */
+ private List<String> labelList;
+
+ /** A ByteBuffer to hold image data, to be feed into Tensorflow Lite as inputs. */
+ private ByteBuffer imgData = null;
+
+ /** An array to hold inference results, to be feed into Tensorflow Lite as outputs. */
+ private byte[][] labelProbArray = null;
+
+ private PriorityQueue<Map.Entry<String, Float>> sortedLabels =
+ new PriorityQueue<>(
+ RESULTS_TO_SHOW,
+ new Comparator<Map.Entry<String, Float>>() {
+ @Override
+ public int compare(Map.Entry<String, Float> o1, Map.Entry<String, Float> o2) {
+ return (o1.getValue()).compareTo(o2.getValue());
+ }
+ });
+
+ /** Initializes an {@code ImageClassifier}. */
+ ImageClassifier(Activity activity) throws IOException {
+ tflite = new Interpreter(loadModelFile(activity));
+ labelList = loadLabelList(activity);
+ imgData =
+ ByteBuffer.allocateDirect(
+ DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);
+ imgData.order(ByteOrder.nativeOrder());
+ labelProbArray = new byte[1][labelList.size()];
+ Log.d(TAG, "Created a Tensorflow Lite Image Classifier.");
+ }
+
+ /** Classifies a frame from the preview stream. */
+ String classifyFrame(Bitmap bitmap) {
+ if (tflite == null) {
+ Log.e(TAG, "Image classifier has not been initialized; Skipped.");
+ return "Uninitialized Classifier.";
+ }
+ convertBitmapToByteBuffer(bitmap);
+ // Here's where the magic happens!!!
+ long startTime = SystemClock.uptimeMillis();
+ tflite.run(imgData, labelProbArray);
+ long endTime = SystemClock.uptimeMillis();
+ Log.d(TAG, "Timecost to run model inference: " + Long.toString(endTime - startTime));
+ String textToShow = printTopKLabels();
+ textToShow = Long.toString(endTime - startTime) + "ms" + textToShow;
+ return textToShow;
+ }
+
+ /** Closes tflite to release resources. */
+ public void close() {
+ tflite.close();
+ tflite = null;
+ }
+
+ /** Reads label list from Assets. */
+ private List<String> loadLabelList(Activity activity) throws IOException {
+ List<String> labelList = new ArrayList<String>();
+ BufferedReader reader =
+ new BufferedReader(new InputStreamReader(activity.getAssets().open(LABEL_PATH)));
+ String line;
+ while ((line = reader.readLine()) != null) {
+ labelList.add(line);
+ }
+ reader.close();
+ return labelList;
+ }
+
+ /** Memory-map the model file in Assets. */
+ private MappedByteBuffer loadModelFile(Activity activity) throws IOException {
+ AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_PATH);
+ FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
+ FileChannel fileChannel = inputStream.getChannel();
+ long startOffset = fileDescriptor.getStartOffset();
+ long declaredLength = fileDescriptor.getDeclaredLength();
+ return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
+ }
+
+ /** Writes Image data into a {@code ByteBuffer}. */
+ private void convertBitmapToByteBuffer(Bitmap bitmap) {
+ if (imgData == null) {
+ return;
+ }
+ imgData.rewind();
+ bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
+ // Convert the image to floating point.
+ int pixel = 0;
+ long startTime = SystemClock.uptimeMillis();
+ for (int i = 0; i < DIM_IMG_SIZE_X; ++i) {
+ for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) {
+ final int val = intValues[pixel++];
+ imgData.put((byte) ((val >> 16) & 0xFF));
+ imgData.put((byte) ((val >> 8) & 0xFF));
+ imgData.put((byte) (val & 0xFF));
+ }
+ }
+ long endTime = SystemClock.uptimeMillis();
+ Log.d(TAG, "Timecost to put values into ByteBuffer: " + Long.toString(endTime - startTime));
+ }
+
+ /** Prints top-K labels, to be shown in UI as the results. */
+ private String printTopKLabels() {
+ for (int i = 0; i < labelList.size(); ++i) {
+ sortedLabels.add(
+ new AbstractMap.SimpleEntry<>(labelList.get(i), (labelProbArray[0][i] & 0xff) / 255.0f));
+ if (sortedLabels.size() > RESULTS_TO_SHOW) {
+ sortedLabels.poll();
+ }
+ }
+ String textToShow = "";
+ final int size = sortedLabels.size();
+ for (int i = 0; i < size; ++i) {
+ Map.Entry<String, Float> label = sortedLabels.poll();
+ textToShow = "\n" + label.getKey() + ":" + Float.toString(label.getValue()) + textToShow;
+ }
+ return textToShow;
+ }
+}
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.png
new file mode 100644
index 0000000000..e0a70008b1
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png
new file mode 100644
index 0000000000..c22509d8df
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.png
new file mode 100644
index 0000000000..a84e3ef52c
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.png
new file mode 100644
index 0000000000..520c2dd100
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png
new file mode 100644
index 0000000000..d68af39186
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.png
new file mode 100644
index 0000000000..1347b09198
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png
new file mode 100644
index 0000000000..15e419b7cc
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.png
new file mode 100644
index 0000000000..fd933333b7
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png
new file mode 100644
index 0000000000..342ce34e16
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml
new file mode 100644
index 0000000000..a84f1bbfa0
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml
@@ -0,0 +1,50 @@
+<?xml version="1.0" encoding="utf-8"?><!--
+ Copyright 2014 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
+ android:layout_width="match_parent"
+ android:layout_height="match_parent">
+
+ <com.example.android.tflitecamerademo.AutoFitTextureView
+ android:id="@+id/texture"
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"
+ android:layout_alignParentBottom="true"
+ android:layout_alignParentStart="true"
+ android:layout_alignParentTop="true" />
+
+ <FrameLayout
+ android:id="@+id/control"
+ android:layout_width="match_parent"
+ android:layout_height="wrap_content"
+ android:layout_alignParentBottom="true"
+ android:layout_alignParentEnd="true"
+ android:layout_alignParentTop="true"
+ android:layout_toRightOf="@id/texture"
+ android:background="@color/control_background"
+ android:orientation="horizontal">
+
+ <TextView android:id="@+id/text"
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"
+ android:paddingTop="20dp"
+ android:textColor="#FFF"
+ android:textSize="20sp"
+ android:textStyle="bold" />
+
+
+ </FrameLayout>
+
+</RelativeLayout>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/activity_camera.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/activity_camera.xml
new file mode 100644
index 0000000000..286e549c65
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/activity_camera.xml
@@ -0,0 +1,22 @@
+<?xml version="1.0" encoding="utf-8"?><!--
+ Copyright 2014 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
+ xmlns:tools="http://schemas.android.com/tools"
+ android:id="@+id/container"
+ android:layout_width="match_parent"
+ android:layout_height="match_parent"
+ android:background="#000"
+ tools:context="com.example.android.tflitecamerademo.CameraActivity" />
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml
new file mode 100644
index 0000000000..15305c436e
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml
@@ -0,0 +1,45 @@
+<?xml version="1.0" encoding="utf-8"?><!--
+ Copyright 2014 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
+ android:layout_width="match_parent"
+ android:layout_height="match_parent">
+
+ <com.example.android.tflitecamerademo.AutoFitTextureView
+ android:id="@+id/texture"
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"
+ android:layout_alignParentStart="true"
+ android:layout_alignParentTop="true" />
+
+ <FrameLayout
+ android:id="@+id/control"
+ android:layout_width="match_parent"
+ android:layout_height="112dp"
+ android:layout_alignParentBottom="true"
+ android:layout_alignParentStart="true"
+ android:background="@color/control_background">
+
+ <TextView android:id="@+id/text"
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"
+ android:paddingLeft="80dp"
+ android:textColor="#FFF"
+ android:textSize="20sp"
+ android:textStyle="bold" />
+
+ </FrameLayout>
+
+</RelativeLayout>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml
new file mode 100644
index 0000000000..22074a2bdb
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml
@@ -0,0 +1,24 @@
+<!--
+ Copyright 2013 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ -->
+
+<resources>
+
+ <!-- Semantic definitions -->
+
+ <dimen name="horizontal_page_margin">@dimen/margin_huge</dimen>
+ <dimen name="vertical_page_margin">@dimen/margin_medium</dimen>
+
+</resources>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml
new file mode 100644
index 0000000000..03d1974183
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml
@@ -0,0 +1,25 @@
+<!--
+ Copyright 2013 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ -->
+
+<resources>
+
+ <style name="Widget.SampleMessage">
+ <item name="android:textAppearance">?android:textAppearanceLarge</item>
+ <item name="android:lineSpacingMultiplier">1.2</item>
+ <item name="android:shadowDy">-6.5</item>
+ </style>
+
+</resources>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v11/template-styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v11/template-styles.xml
new file mode 100644
index 0000000000..8c1ea66f28
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v11/template-styles.xml
@@ -0,0 +1,22 @@
+<!--
+ Copyright 2013 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ -->
+
+<resources>
+
+ <!-- Activity themes -->
+ <style name="Theme.Base" parent="android:Theme.Holo.Light" />
+
+</resources>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v21/base-colors.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v21/base-colors.xml
new file mode 100644
index 0000000000..8b6ec3f85d
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v21/base-colors.xml
@@ -0,0 +1,21 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ Copyright 2013 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+<resources>
+
+
+</resources>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v21/base-template-styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v21/base-template-styles.xml
new file mode 100644
index 0000000000..c778e4f98a
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v21/base-template-styles.xml
@@ -0,0 +1,24 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ Copyright 2013 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+<resources>
+
+ <!-- Activity themes -->
+ <style name="Theme.Base" parent="android:Theme.Material.Light">
+ </style>
+
+</resources>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml
new file mode 100644
index 0000000000..ab7d3fd496
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml
@@ -0,0 +1,30 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ Copyright 2013 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+<resources>
+ <string name="app_name">TfLiteCameraDemo</string>
+ <string name="intro_message">
+ <![CDATA[
+
+
+ This sample demonstrates the basic use of TfLite API. Check the source code to see how
+ you can use TfLite for efficient, on-device inference with trained TensorFlow models.
+
+
+ ]]>
+ </string>
+</resources>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/colors.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/colors.xml
new file mode 100644
index 0000000000..4b75d2b2bd
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/colors.xml
@@ -0,0 +1,19 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!--
+ Copyright 2015 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<resources>
+ <color name="control_background">#cc4285f4</color>
+</resources>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml
new file mode 100644
index 0000000000..a08ec3eb62
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml
@@ -0,0 +1,24 @@
+<?xml version="1.0" encoding="utf-8"?><!--
+ Copyright 2014 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<resources>
+ <string name="picture">Picture</string>
+ <string name="description_info">Info</string>
+ <string name="request_permission">This sample needs camera permission.</string>
+ <string name="camera_error">This device doesn\'t support Camera2 API.</string>
+ <string name="toggle_turn_on">NN:On</string>
+ <string name="toggle_turn_off">NN:Off</string>
+ <string name="toggle">Use NNAPI</string>
+</resources>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml
new file mode 100644
index 0000000000..3f3bdfb494
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml
@@ -0,0 +1,18 @@
+<?xml version="1.0" encoding="utf-8"?><!--
+ Copyright 2014 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<resources>
+ <style name="MaterialTheme" parent="android:Theme.Material.Light.NoActionBar.Fullscreen" />
+</resources>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/template-dimens.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/template-dimens.xml
new file mode 100644
index 0000000000..39e710b5ca
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/template-dimens.xml
@@ -0,0 +1,32 @@
+<!--
+ Copyright 2013 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ -->
+
+<resources>
+
+ <!-- Define standard dimensions to comply with Holo-style grids and rhythm. -->
+
+ <dimen name="margin_tiny">4dp</dimen>
+ <dimen name="margin_small">8dp</dimen>
+ <dimen name="margin_medium">16dp</dimen>
+ <dimen name="margin_large">32dp</dimen>
+ <dimen name="margin_huge">64dp</dimen>
+
+ <!-- Semantic definitions -->
+
+ <dimen name="horizontal_page_margin">@dimen/margin_medium</dimen>
+ <dimen name="vertical_page_margin">@dimen/margin_medium</dimen>
+
+</resources>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/template-styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/template-styles.xml
new file mode 100644
index 0000000000..6e7d593dd8
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/template-styles.xml
@@ -0,0 +1,42 @@
+<!--
+ Copyright 2013 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ -->
+
+<resources>
+
+ <!-- Activity themes -->
+
+ <style name="Theme.Base" parent="android:Theme.Light" />
+
+ <style name="Theme.Sample" parent="Theme.Base" />
+
+ <style name="AppTheme" parent="Theme.Sample" />
+ <!-- Widget styling -->
+
+ <style name="Widget" />
+
+ <style name="Widget.SampleMessage">
+ <item name="android:textAppearance">?android:textAppearanceMedium</item>
+ <item name="android:lineSpacingMultiplier">1.1</item>
+ </style>
+
+ <style name="Widget.SampleMessageTile">
+ <item name="android:background">@drawable/tile</item>
+ <item name="android:shadowColor">#7F000000</item>
+ <item name="android:shadowDy">-3.5</item>
+ <item name="android:shadowRadius">2</item>
+ </style>
+
+</resources>
diff --git a/tensorflow/contrib/lite/java/demo/build.gradle b/tensorflow/contrib/lite/java/demo/build.gradle
new file mode 100644
index 0000000000..b78a0b86c9
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/build.gradle
@@ -0,0 +1,23 @@
+// Top-level build file where you can add configuration options common to all sub-projects/modules.
+
+buildscript {
+ repositories {
+ jcenter()
+ }
+ dependencies {
+ classpath 'com.android.tools.build:gradle:2.3.1'
+
+ // NOTE: Do not place your application dependencies here; they belong
+ // in the individual module build.gradle files
+ }
+}
+
+allprojects {
+ repositories {
+ jcenter()
+ }
+}
+
+task clean(type: Delete) {
+ delete rootProject.buildDir
+}
diff --git a/tensorflow/contrib/lite/java/demo/gradle.properties b/tensorflow/contrib/lite/java/demo/gradle.properties
new file mode 100644
index 0000000000..aac7c9b461
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/gradle.properties
@@ -0,0 +1,17 @@
+# Project-wide Gradle settings.
+
+# IDE (e.g. Android Studio) users:
+# Gradle settings configured through the IDE *will override*
+# any settings specified in this file.
+
+# For more details on how to configure your build environment visit
+# http://www.gradle.org/docs/current/userguide/build_environment.html
+
+# Specifies the JVM arguments used for the daemon process.
+# The setting is particularly useful for tweaking memory settings.
+org.gradle.jvmargs=-Xmx1536m
+
+# When configured, Gradle will run in incubating parallel mode.
+# This option should only be used with decoupled projects. More details, visit
+# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
+# org.gradle.parallel=true
diff --git a/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.jar b/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.jar
new file mode 100644
index 0000000000..13372aef5e
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.jar
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.properties b/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.properties
new file mode 100644
index 0000000000..fa7a38a0e4
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.properties
@@ -0,0 +1,6 @@
+#Thu Sep 28 09:01:41 PDT 2017
+distributionBase=GRADLE_USER_HOME
+distributionPath=wrapper/dists
+zipStoreBase=GRADLE_USER_HOME
+zipStorePath=wrapper/dists
+distributionUrl=https\://services.gradle.org/distributions/gradle-3.3-all.zip
diff --git a/tensorflow/contrib/lite/java/demo/gradlew b/tensorflow/contrib/lite/java/demo/gradlew
new file mode 100755
index 0000000000..9d82f78915
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/gradlew
@@ -0,0 +1,160 @@
+#!/usr/bin/env bash
+
+##############################################################################
+##
+## Gradle start up script for UN*X
+##
+##############################################################################
+
+# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+DEFAULT_JVM_OPTS=""
+
+APP_NAME="Gradle"
+APP_BASE_NAME=`basename "$0"`
+
+# Use the maximum available, or set MAX_FD != -1 to use that value.
+MAX_FD="maximum"
+
+warn ( ) {
+ echo "$*"
+}
+
+die ( ) {
+ echo
+ echo "$*"
+ echo
+ exit 1
+}
+
+# OS specific support (must be 'true' or 'false').
+cygwin=false
+msys=false
+darwin=false
+case "`uname`" in
+ CYGWIN* )
+ cygwin=true
+ ;;
+ Darwin* )
+ darwin=true
+ ;;
+ MINGW* )
+ msys=true
+ ;;
+esac
+
+# Attempt to set APP_HOME
+# Resolve links: $0 may be a link
+PRG="$0"
+# Need this for relative symlinks.
+while [ -h "$PRG" ] ; do
+ ls=`ls -ld "$PRG"`
+ link=`expr "$ls" : '.*-> \(.*\)$'`
+ if expr "$link" : '/.*' > /dev/null; then
+ PRG="$link"
+ else
+ PRG=`dirname "$PRG"`"/$link"
+ fi
+done
+SAVED="`pwd`"
+cd "`dirname \"$PRG\"`/" >/dev/null
+APP_HOME="`pwd -P`"
+cd "$SAVED" >/dev/null
+
+CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
+
+# Determine the Java command to use to start the JVM.
+if [ -n "$JAVA_HOME" ] ; then
+ if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
+ # IBM's JDK on AIX uses strange locations for the executables
+ JAVACMD="$JAVA_HOME/jre/sh/java"
+ else
+ JAVACMD="$JAVA_HOME/bin/java"
+ fi
+ if [ ! -x "$JAVACMD" ] ; then
+ die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+ fi
+else
+ JAVACMD="java"
+ which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+fi
+
+# Increase the maximum file descriptors if we can.
+if [ "$cygwin" = "false" -a "$darwin" = "false" ] ; then
+ MAX_FD_LIMIT=`ulimit -H -n`
+ if [ $? -eq 0 ] ; then
+ if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
+ MAX_FD="$MAX_FD_LIMIT"
+ fi
+ ulimit -n $MAX_FD
+ if [ $? -ne 0 ] ; then
+ warn "Could not set maximum file descriptor limit: $MAX_FD"
+ fi
+ else
+ warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
+ fi
+fi
+
+# For Darwin, add options to specify how the application appears in the dock
+if $darwin; then
+ GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
+fi
+
+# For Cygwin, switch paths to Windows format before running java
+if $cygwin ; then
+ APP_HOME=`cygpath --path --mixed "$APP_HOME"`
+ CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
+ JAVACMD=`cygpath --unix "$JAVACMD"`
+
+ # We build the pattern for arguments to be converted via cygpath
+ ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
+ SEP=""
+ for dir in $ROOTDIRSRAW ; do
+ ROOTDIRS="$ROOTDIRS$SEP$dir"
+ SEP="|"
+ done
+ OURCYGPATTERN="(^($ROOTDIRS))"
+ # Add a user-defined pattern to the cygpath arguments
+ if [ "$GRADLE_CYGPATTERN" != "" ] ; then
+ OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
+ fi
+ # Now convert the arguments - kludge to limit ourselves to /bin/sh
+ i=0
+ for arg in "$@" ; do
+ CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
+ CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
+
+ if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
+ eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
+ else
+ eval `echo args$i`="\"$arg\""
+ fi
+ i=$((i+1))
+ done
+ case $i in
+ (0) set -- ;;
+ (1) set -- "$args0" ;;
+ (2) set -- "$args0" "$args1" ;;
+ (3) set -- "$args0" "$args1" "$args2" ;;
+ (4) set -- "$args0" "$args1" "$args2" "$args3" ;;
+ (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
+ (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
+ (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
+ (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
+ (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
+ esac
+fi
+
+# Split up the JVM_OPTS And GRADLE_OPTS values into an array, following the shell quoting and substitution rules
+function splitJvmOpts() {
+ JVM_OPTS=("$@")
+}
+eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS
+JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME"
+
+exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@"
diff --git a/tensorflow/contrib/lite/java/demo/gradlew.bat b/tensorflow/contrib/lite/java/demo/gradlew.bat
new file mode 100644
index 0000000000..8a0b282aa6
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/gradlew.bat
@@ -0,0 +1,90 @@
+@if "%DEBUG%" == "" @echo off
+@rem ##########################################################################
+@rem
+@rem Gradle startup script for Windows
+@rem
+@rem ##########################################################################
+
+@rem Set local scope for the variables with windows NT shell
+if "%OS%"=="Windows_NT" setlocal
+
+@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+set DEFAULT_JVM_OPTS=
+
+set DIRNAME=%~dp0
+if "%DIRNAME%" == "" set DIRNAME=.
+set APP_BASE_NAME=%~n0
+set APP_HOME=%DIRNAME%
+
+@rem Find java.exe
+if defined JAVA_HOME goto findJavaFromJavaHome
+
+set JAVA_EXE=java.exe
+%JAVA_EXE% -version >NUL 2>&1
+if "%ERRORLEVEL%" == "0" goto init
+
+echo.
+echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:findJavaFromJavaHome
+set JAVA_HOME=%JAVA_HOME:"=%
+set JAVA_EXE=%JAVA_HOME%/bin/java.exe
+
+if exist "%JAVA_EXE%" goto init
+
+echo.
+echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:init
+@rem Get command-line arguments, handling Windowz variants
+
+if not "%OS%" == "Windows_NT" goto win9xME_args
+if "%@eval[2+2]" == "4" goto 4NT_args
+
+:win9xME_args
+@rem Slurp the command line arguments.
+set CMD_LINE_ARGS=
+set _SKIP=2
+
+:win9xME_args_slurp
+if "x%~1" == "x" goto execute
+
+set CMD_LINE_ARGS=%*
+goto execute
+
+:4NT_args
+@rem Get arguments from the 4NT Shell from JP Software
+set CMD_LINE_ARGS=%$
+
+:execute
+@rem Setup the command line
+
+set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
+
+@rem Execute Gradle
+"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
+
+:end
+@rem End local scope for the variables with windows NT shell
+if "%ERRORLEVEL%"=="0" goto mainEnd
+
+:fail
+rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
+rem the _cmd.exe /c_ return code!
+if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
+exit /b 1
+
+:mainEnd
+if "%OS%"=="Windows_NT" endlocal
+
+:omega
diff --git a/tensorflow/contrib/lite/java/demo/settings.gradle b/tensorflow/contrib/lite/java/demo/settings.gradle
new file mode 100644
index 0000000000..e7b4def49c
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/settings.gradle
@@ -0,0 +1 @@
+include ':app'
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
new file mode 100644
index 0000000000..d63c299589
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
@@ -0,0 +1,76 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+/** Type of elements in a {@link TfLiteTensor}. */
+enum DataType {
+ /** 32-bit single precision floating point. */
+ FLOAT32(1),
+
+ /** 32-bit signed integer. */
+ INT32(2),
+
+ /** 8-bit unsigned integer. */
+ UINT8(3),
+
+ /** 64-bit signed integer. */
+ INT64(4),
+
+ /** A {@link ByteBuffer}. */
+ BYTEBUFFER(999);
+
+ private final int value;
+
+ DataType(int value) {
+ this.value = value;
+ }
+
+ /** Corresponding value of the kTfLite* enum in the TensorFlow Lite CC API. */
+ int getNumber() {
+ return value;
+ }
+
+ /** Converts an integer to the corresponding type. */
+ static DataType fromNumber(int c) {
+ for (DataType t : values) {
+ if (t.value == c) {
+ return t;
+ }
+ }
+ throw new IllegalArgumentException(
+ "DataType " + c + " is not recognized in Java (version " + TensorFlowLite.version() + ")");
+ }
+
+ /** Returns byte size of the type. */
+ int elemByteSize() {
+ switch (this) {
+ case FLOAT32:
+ return 4;
+ case INT32:
+ return 4;
+ case UINT8:
+ return 1;
+ case INT64:
+ return 8;
+ case BYTEBUFFER:
+ return 1;
+ }
+ throw new IllegalArgumentException("DataType " + this + " is not supported yet");
+ }
+
+ // Cached to avoid copying it
+ private static final DataType[] values = values();
+}
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
new file mode 100644
index 0000000000..dd883d69d2
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -0,0 +1,172 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+import java.io.File;
+import java.nio.MappedByteBuffer;
+import java.util.HashMap;
+import java.util.Map;
+import javax.validation.constraints.NotNull;
+
+/**
+ * Driver class to drive model inference with TensorFlow Lite.
+ *
+ * <p>A {@code Interpreter} encapsulates a pre-trained TensorFlow Lite model, in which operations
+ * are executed for model inference.
+ *
+ * <p>For example, if a model takes only one input and returns only one output:
+ *
+ * <pre>{@code
+ * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
+ * interpreter.run(input, output);
+ * }
+ * }</pre>
+ *
+ * <p>If a model takes multiple inputs or outputs:
+ *
+ * <pre>{@code
+ * Object[] inputs = {input0, input1, ...};
+ * Map<Integer, Object> map_of_indices_to_outputs = new HashMap<>();
+ * float[][][] ith_output = new float[3][2][4];
+ * map_of_indices_to_outputs.put(i, ith_output);
+ * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
+ * interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
+ * }
+ * }</pre>
+ *
+ * <p>Orders of inputs and outputs are determined when converting TensorFlow model to TensorFlowLite
+ * model with Toco.
+ *
+ * <p><b>WARNING:</b>Instances of a {@code Interpreter} is <b>not</b> thread-safe. A {@code
+ * Interpreter} owns resources that <b>must</b> be explicitly freed by invoking {@link #close()}
+ */
+public final class Interpreter implements AutoCloseable {
+
+ /**
+ * Initializes a {@code Interpreter}
+ *
+ * @param modelFile: a File of a pre-trained TF Lite model.
+ */
+ public Interpreter(@NotNull File modelFile) {
+ if (modelFile == null) {
+ return;
+ }
+ wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath());
+ }
+
+ /**
+ * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file.
+ *
+ * <p>The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code
+ * Interpreter}.
+ */
+ public Interpreter(@NotNull MappedByteBuffer mappedByteBuffer) {
+ wrapper = new NativeInterpreterWrapper(mappedByteBuffer);
+ }
+
+ /**
+ * Runs model inference if the model takes only one input, and provides only one output.
+ *
+ * @param input an array or multidimensional array, or a {@link ByteBuffer} of primitive types
+ * including int, float, long, and byte. {@link ByteBuffer} is the preferred way to pass large
+ * input data. When {@link ByteBuffer} is used, its content should remain unchanged until
+ * model inference is done.
+ * @param output a multidimensional array of output data.
+ */
+ public void run(@NotNull Object input, @NotNull Object output) {
+ Object[] inputs = {input};
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, output);
+ runForMultipleInputsOutputs(inputs, outputs);
+ }
+
+ /**
+ * Runs model inference if the model takes multiple inputs, or returns multiple outputs.
+ *
+ * @param inputs an array of input data. The inputs should be in the same order as inputs of the
+ * model. Each input can be an array or multidimensional array, or a {@link ByteBuffer} of
+ * primitive types including int, float, long, and byte. {@link ByteBuffer} is the preferred
+ * way to pass large input data. When {@link ByteBuffer} is used, its content should remain
+ * unchanged until model inference is done.
+ * @param outputs a map mapping output indices to multidimensional arrays of output data. It only
+ * needs to keep entries for the outputs to be used.
+ */
+ public void runForMultipleInputsOutputs(
+ @NotNull Object[] inputs, @NotNull Map<Integer, Object> outputs) {
+ if (wrapper == null) {
+ throw new IllegalStateException("The Interpreter has already been closed.");
+ }
+ Tensor[] tensors = wrapper.run(inputs);
+ if (outputs == null || tensors == null || outputs.size() > tensors.length) {
+ throw new IllegalArgumentException("Outputs do not match with model outputs.");
+ }
+ final int size = tensors.length;
+ for (Integer idx : outputs.keySet()) {
+ if (idx == null || idx < 0 || idx >= size) {
+ throw new IllegalArgumentException(
+ String.format("Invalid index of output %d (should be in range [0, %d))", idx, size));
+ }
+ tensors[idx].copyTo(outputs.get(idx));
+ }
+ }
+
+ /**
+ * Resizes idx-th input of the native model to the given dims.
+ *
+ * <p>IllegalArgumentException will be thrown if it fails to resize.
+ */
+ public void resizeInput(int idx, @NotNull int[] dims) {
+ if (wrapper == null) {
+ throw new IllegalStateException("The Interpreter has already been closed.");
+ }
+ wrapper.resizeInput(idx, dims);
+ }
+
+ /**
+ * Gets index of an input given the op name of the input.
+ *
+ * <p>IllegalArgumentException will be thrown if the op name does not exist in the model file used
+ * to initialize the {@link Interpreter}.
+ */
+ public int getInputIndex(String opName) {
+ if (wrapper == null) {
+ throw new IllegalStateException("The Interpreter has already been closed.");
+ }
+ return wrapper.getInputIndex(opName);
+ }
+
+ /**
+ * Gets index of an output given the op name of the output.
+ *
+ * <p>IllegalArgumentException will be thrown if the op name does not exist in the model file used
+ * to initialize the {@link Interpreter}.
+ */
+ public int getOutputIndex(String opName) {
+ if (wrapper == null) {
+ throw new IllegalStateException("The Interpreter has already been closed.");
+ }
+ return wrapper.getOutputIndex(opName);
+ }
+
+ /** Release resources associated with the {@code Interpreter}. */
+ @Override
+ public void close() {
+ wrapper.close();
+ wrapper = null;
+ }
+
+ NativeInterpreterWrapper wrapper;
+}
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
new file mode 100644
index 0000000000..1939a078ad
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -0,0 +1,276 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+import java.lang.reflect.Array;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.MappedByteBuffer;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A wrapper wraps native interpreter and controls model execution.
+ *
+ * <p><b>WARNING:</b> Resources consumed by the {@code NativeInterpreterWrapper} object must be
+ * explicitly freed by invoking the {@link #close()} method when the {@code
+ * NativeInterpreterWrapper} object is no longer needed.
+ */
+final class NativeInterpreterWrapper implements AutoCloseable {
+
+ NativeInterpreterWrapper(String modelPath) {
+ errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
+ modelHandle = createModel(modelPath, errorHandle);
+ interpreterHandle = createInterpreter(modelHandle);
+ }
+
+ /**
+ * Initializes a {@code NativeInterpreterWrapper} with a {@code MappedByteBuffer}. The
+ * MappedByteBuffer should not be modified after the construction of a {@code
+ * NativeInterpreterWrapper}.
+ */
+ NativeInterpreterWrapper(MappedByteBuffer mappedByteBuffer) {
+ modelByteBuffer = mappedByteBuffer;
+ errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
+ modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
+ interpreterHandle = createInterpreter(modelHandle);
+ }
+
+ /** Releases resources associated with this {@code NativeInterpreterWrapper}. */
+ @Override
+ public void close() {
+ delete(errorHandle, modelHandle, interpreterHandle);
+ errorHandle = 0;
+ modelHandle = 0;
+ interpreterHandle = 0;
+ modelByteBuffer = null;
+ inputsIndexes = null;
+ outputsIndexes = null;
+ }
+
+ /** Sets inputs, runs model inference and returns outputs. */
+ Tensor[] run(Object[] inputs) {
+ if (inputs == null || inputs.length == 0) {
+ throw new IllegalArgumentException("Invalid inputs. Inputs should not be null or empty.");
+ }
+ int[] dataTypes = new int[inputs.length];
+ Object[] sizes = new Object[inputs.length];
+ int[] numsOfBytes = new int[inputs.length];
+ for (int i = 0; i < inputs.length; ++i) {
+ DataType dataType = dataTypeOf(inputs[i]);
+ dataTypes[i] = dataType.getNumber();
+ if (dataType == DataType.BYTEBUFFER) {
+ ByteBuffer buffer = (ByteBuffer) inputs[i];
+ if (buffer.order() != ByteOrder.nativeOrder()) {
+ throw new IllegalArgumentException(
+ "Invalid ByteBuffer. It shoud use ByteOrder.nativeOrder().");
+ }
+ numsOfBytes[i] = buffer.limit();
+ sizes[i] = getInputDims(interpreterHandle, i, numsOfBytes[i]);
+ } else if (isNonEmptyArray(inputs[i])) {
+ int[] dims = shapeOf(inputs[i]);
+ sizes[i] = dims;
+ numsOfBytes[i] = dataType.elemByteSize() * numElements(dims);
+ } else {
+ throw new IllegalArgumentException(
+ String.format(
+ "%d-th element of the %d inputs is not an array or a ByteBuffer.",
+ i, inputs.length));
+ }
+ }
+ long[] outputsHandles =
+ run(interpreterHandle, errorHandle, sizes, dataTypes, numsOfBytes, inputs);
+ if (outputsHandles == null || outputsHandles.length == 0) {
+ throw new IllegalStateException("Interpreter has no outputs.");
+ }
+ Tensor[] outputs = new Tensor[outputsHandles.length];
+ for (int i = 0; i < outputsHandles.length; ++i) {
+ outputs[i] = Tensor.fromHandle(outputsHandles[i]);
+ }
+ return outputs;
+ }
+
+ /** Resizes dimensions of a specific input. */
+ void resizeInput(int idx, int[] dims) {
+ resizeInput(interpreterHandle, errorHandle, idx, dims);
+ }
+
+ void setUseNNAPI(boolean useNNAPI) {
+ useNNAPI(interpreterHandle, useNNAPI);
+ }
+
+ /** Gets index of an input given its name. */
+ int getInputIndex(String name) {
+ if (inputsIndexes == null) {
+ String[] names = getInputNames(interpreterHandle);
+ inputsIndexes = new HashMap<>();
+ if (names != null) {
+ for (int i = 0; i < names.length; ++i) {
+ inputsIndexes.put(names[i], i);
+ }
+ }
+ }
+ if (inputsIndexes.containsKey(name)) {
+ return inputsIndexes.get(name);
+ } else {
+ throw new IllegalArgumentException(
+ String.format(
+ "%s is not a valid name for any input. The indexes of the inputs are %s",
+ name, inputsIndexes.toString()));
+ }
+ }
+
+ /** Gets index of an output given its name. */
+ int getOutputIndex(String name) {
+ if (outputsIndexes == null) {
+ String[] names = getOutputNames(interpreterHandle);
+ outputsIndexes = new HashMap<>();
+ if (names != null) {
+ for (int i = 0; i < names.length; ++i) {
+ outputsIndexes.put(names[i], i);
+ }
+ }
+ }
+ if (outputsIndexes.containsKey(name)) {
+ return outputsIndexes.get(name);
+ } else {
+ throw new IllegalArgumentException(
+ String.format(
+ "%s is not a valid name for any output. The indexes of the outputs are %s",
+ name, outputsIndexes.toString()));
+ }
+ }
+
+ static int numElements(int[] shape) {
+ if (shape == null) {
+ return 0;
+ }
+ int n = 1;
+ for (int i = 0; i < shape.length; i++) {
+ n *= shape[i];
+ }
+ return n;
+ }
+
+ static boolean isNonEmptyArray(Object o) {
+ return (o != null && o.getClass().isArray() && Array.getLength(o) != 0);
+ }
+
+ /** Returns the type of the data. */
+ static DataType dataTypeOf(Object o) {
+ if (o != null) {
+ Class<?> c = o.getClass();
+ while (c.isArray()) {
+ c = c.getComponentType();
+ }
+ if (float.class.equals(c)) {
+ return DataType.FLOAT32;
+ } else if (int.class.equals(c)) {
+ return DataType.INT32;
+ } else if (byte.class.equals(c)) {
+ return DataType.UINT8;
+ } else if (long.class.equals(c)) {
+ return DataType.INT64;
+ } else if (ByteBuffer.class.isInstance(o)) {
+ return DataType.BYTEBUFFER;
+ }
+ }
+ throw new IllegalArgumentException("cannot resolve DataType of " + o.getClass().getName());
+ }
+
+ /** Returns the shape of an object as an int array. */
+ static int[] shapeOf(Object o) {
+ int size = numDimensions(o);
+ int[] dimensions = new int[size];
+ fillShape(o, 0, dimensions);
+ return dimensions;
+ }
+
+ static int numDimensions(Object o) {
+ if (o == null || !o.getClass().isArray()) {
+ return 0;
+ }
+ if (Array.getLength(o) == 0) {
+ throw new IllegalArgumentException("array lengths cannot be 0.");
+ }
+ return 1 + numDimensions(Array.get(o, 0));
+ }
+
+ static void fillShape(Object o, int dim, int[] shape) {
+ if (shape == null || dim == shape.length) {
+ return;
+ }
+ final int len = Array.getLength(o);
+ if (shape[dim] == 0) {
+ shape[dim] = len;
+ } else if (shape[dim] != len) {
+ throw new IllegalArgumentException(
+ String.format("mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim));
+ }
+ for (int i = 0; i < len; ++i) {
+ fillShape(Array.get(o, i), dim + 1, shape);
+ }
+ }
+
+ private static final int ERROR_BUFFER_SIZE = 512;
+
+ private long errorHandle;
+
+ private long interpreterHandle;
+
+ private long modelHandle;
+
+ private int inputSize;
+
+ private MappedByteBuffer modelByteBuffer;
+
+ private Map<String, Integer> inputsIndexes;
+
+ private Map<String, Integer> outputsIndexes;
+
+ private static native String[] getInputNames(long interpreterHandle);
+
+ private static native String[] getOutputNames(long interpreterHandle);
+
+ private static native void resizeInput(
+ long interpreterHandle, long errorHandle, int inputIdx, int[] dims);
+
+ private static native void useNNAPI(long interpreterHandle, boolean state);
+
+ private static native long createErrorReporter(int size);
+
+ private static native long createModel(String modelPathOrBuffer, long errorHandle);
+
+ private static native long createModelWithBuffer(MappedByteBuffer modelBuffer, long errorHandle);
+
+ private static native long createInterpreter(long modelHandle);
+
+ private static native long[] run(
+ long interpreterHandle,
+ long errorHandle,
+ Object[] sizes,
+ int[] dtypes,
+ int[] numsOfBytes,
+ Object[] values);
+
+ private static native void delete(long errorHandle, long modelHandle, long interpreterHandle);
+
+ private static native int[] getInputDims(long interpreterHandle, int inputIdx, int numBytes);
+
+ static {
+ TensorFlowLite.init();
+ }
+}
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
new file mode 100644
index 0000000000..54ace6c63c
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
@@ -0,0 +1,71 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+import java.util.Arrays;
+
+/**
+ * A typed multi-dimensional array used in Tensorflow Lite.
+ *
+ * <p>The native handle of a {@code Tensor} belongs to {@code NativeInterpreterWrapper}, thus not
+ * needed to be closed here.
+ */
+final class Tensor {
+
+ static Tensor fromHandle(long nativeHandle) {
+ return new Tensor(nativeHandle);
+ }
+
+ /** Reads Tensor content into an array. */
+ <T> T copyTo(T dst) {
+ if (NativeInterpreterWrapper.dataTypeOf(dst) != dtype) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Cannot convert an TensorFlowLite tensor with type %s to a Java object of "
+ + "type %s (which is compatible with the TensorFlowLite type %s)",
+ dtype, dst.getClass().getName(), NativeInterpreterWrapper.dataTypeOf(dst)));
+ }
+ int[] dstShape = NativeInterpreterWrapper.shapeOf(dst);
+ if (!Arrays.equals(dstShape, shapeCopy)) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Shape of output target %s does not match with the shape of the Tensor %s.",
+ Arrays.toString(dstShape), Arrays.toString(shapeCopy)));
+ }
+ readMultiDimensionalArray(nativeHandle, dst);
+ return dst;
+ }
+
+ final long nativeHandle;
+ final DataType dtype;
+ final int[] shapeCopy;
+
+ private Tensor(long nativeHandle) {
+ this.nativeHandle = nativeHandle;
+ this.dtype = DataType.fromNumber(dtype(nativeHandle));
+ this.shapeCopy = shape(nativeHandle);
+ }
+
+ private static native int dtype(long handle);
+
+ private static native int[] shape(long handle);
+
+ private static native void readMultiDimensionalArray(long handle, Object value);
+
+ static {
+ TensorFlowLite.init();
+ }
+}
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java
new file mode 100644
index 0000000000..711638a9f9
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java
@@ -0,0 +1,44 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+/** Static utility methods loading the TensorFlowLite runtime. */
+public final class TensorFlowLite {
+
+ private static final String LIBNAME = "tensorflowlite_jni";
+
+ private TensorFlowLite() {}
+
+ /** Returns the version of the underlying TensorFlowLite runtime. */
+ public static native String version();
+
+ /**
+ * Load the TensorFlowLite runtime C library.
+ */
+ static boolean init() {
+ try {
+ System.loadLibrary(LIBNAME);
+ return true;
+ } catch (UnsatisfiedLinkError e) {
+ System.err.println("TensorFlowLite: failed to load native library: " + e.getMessage());
+ return false;
+ }
+ }
+
+ static {
+ init();
+ }
+}
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/package-info.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/package-info.java
new file mode 100644
index 0000000000..68e6a0f578
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/package-info.java
@@ -0,0 +1,17 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/** Defines classes to load and execute TensorFlowLite models. */
+package org.tensorflow.lite;
diff --git a/tensorflow/contrib/lite/java/src/main/native/BUILD b/tensorflow/contrib/lite/java/src/main/native/BUILD
new file mode 100644
index 0000000000..15806d57c8
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/native/BUILD
@@ -0,0 +1,108 @@
+# Description:
+# Java Native Interface (JNI) library intended for implementing the
+# TensorFlow Lite Java API using the TensorFlow Lite CC library.
+
+package(default_visibility = ["//visibility:public"])
+
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
+
+licenses(["notice"]) # Apache 2.0
+
+cc_library(
+ name = "native_framework_only",
+ srcs = [
+ "exception_jni.cc",
+ "nativeinterpreterwrapper_jni.cc",
+ "tensor_jni.cc",
+ "tensorflow_lite_jni.cc",
+ ] + select({
+ # The Android toolchain makes "jni.h" available in the include path.
+ # For non-Android toolchains, generate jni.h and jni_md.h.
+ "//tensorflow:android": [],
+ "//conditions:default": [
+ ":jni.h",
+ ":jni_md.h",
+ ],
+ }),
+ hdrs = [
+ "exception_jni.h",
+ "nativeinterpreterwrapper_jni.h",
+ "tensor_jni.h",
+ "tensorflow_lite_jni.h",
+ ],
+ copts = tflite_copts(),
+ includes = select({
+ "//tensorflow:android": [],
+ "//conditions:default": ["."],
+ }),
+ linkopts = [
+ "-lm",
+ "-ldl",
+ ],
+ deps = [
+ "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:schema_fbs_version",
+ ],
+ alwayslink = 1,
+)
+
+# Silly rules to make
+# #include <jni.h>
+# in the source headers work
+# (in combination with the "includes" attribute of the tf_cuda_library rule
+# above. Not needed when using the Android toolchain).
+#
+# Inspired from:
+# https://github.com/bazelbuild/bazel/blob/f99a0543f8d97339d32075c7176b79f35be84606/src/main/native/BUILD
+# but hopefully there is a simpler alternative to this.
+genrule(
+ name = "copy_jni_h",
+ srcs = ["@bazel_tools//tools/jdk:jni_header"],
+ outs = ["jni.h"],
+ cmd = "cp -f $< $@",
+)
+
+genrule(
+ name = "copy_jni_md_h",
+ srcs = select({
+ "//tensorflow:darwin": ["@bazel_tools//tools/jdk:jni_md_header-darwin"],
+ "//conditions:default": ["@bazel_tools//tools/jdk:jni_md_header-linux"],
+ }),
+ outs = ["jni_md.h"],
+ cmd = "cp -f $< $@",
+)
+
+# This includes all ops. If you want a smaller binary, you should copy and
+# modify builtin_ops_jni.cc. You should then link your binary against both
+# ":native_framework_only" and your own version of ":native_builtin_ops".
+cc_library(
+ name = "native",
+ srcs = [
+ "builtin_ops_jni.cc",
+ ],
+ copts = tflite_copts(),
+ deps = [
+ ":native_framework_only",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ ],
+ alwayslink = 1,
+)
+
+exports_files(
+ [
+ "version_script.lds",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/java/src/main/native/builtin_ops_jni.cc b/tensorflow/contrib/lite/java/src/main/native/builtin_ops_jni.cc
new file mode 100644
index 0000000000..cce356370f
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/native/builtin_ops_jni.cc
@@ -0,0 +1,29 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/kernels/register.h"
+
+namespace tflite {
+
+// The JNI code in interpreter_jni.cc expects a CreateOpResolver() function in
+// the tflite namespace. This one instantiates a BuiltinOpResolver, with all the
+// builtin ops. For smaller binary sizes users should avoid linking this in, and
+// should provide a custom make CreateOpResolver() instead.
+std::unique_ptr<OpResolver> CreateOpResolver() { // NOLINT
+ return std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver>(
+ new tflite::ops::builtin::BuiltinOpResolver());
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/java/src/main/native/exception_jni.cc b/tensorflow/contrib/lite/java/src/main/native/exception_jni.cc
new file mode 100644
index 0000000000..1578c9e3dd
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/native/exception_jni.cc
@@ -0,0 +1,66 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stdarg.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h"
+
+const char kIllegalArgumentException[] = "java/lang/IllegalArgumentException";
+const char kIllegalStateException[] = "java/lang/IllegalStateException";
+const char kNullPointerException[] = "java/lang/NullPointerException";
+const char kIndexOutOfBoundsException[] = "java/lang/IndexOutOfBoundsException";
+const char kUnsupportedOperationException[] =
+ "java/lang/UnsupportedOperationException";
+
+void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...) {
+ va_list args;
+ va_start(args, fmt);
+ const size_t max_msg_len = 512;
+ auto* message = static_cast<char*>(malloc(max_msg_len));
+ if (vsnprintf(message, max_msg_len, fmt, args) >= 0) {
+ env->ThrowNew(env->FindClass(clazz), message);
+ } else {
+ env->ThrowNew(env->FindClass(clazz), "");
+ }
+ free(message);
+ va_end(args);
+}
+
+BufferErrorReporter::BufferErrorReporter(JNIEnv* env, int limit) {
+ buffer_ = new char[limit];
+ if (!buffer_) {
+ throwException(env, kNullPointerException,
+ "Malloc of BufferErrorReporter to hold %d char failed.",
+ limit);
+ return;
+ }
+ start_idx_ = 0;
+ end_idx_ = limit - 1;
+}
+
+BufferErrorReporter::~BufferErrorReporter() { delete[] buffer_; }
+
+int BufferErrorReporter::Report(const char* format, va_list args) {
+ int size = 0;
+ if (start_idx_ < end_idx_) {
+ size = vsnprintf(buffer_ + start_idx_, end_idx_ - start_idx_, format, args);
+ }
+ start_idx_ += size;
+ return size;
+}
+
+const char* BufferErrorReporter::CachedErrorMessage() { return buffer_; }
diff --git a/tensorflow/contrib/lite/java/src/main/native/exception_jni.h b/tensorflow/contrib/lite/java/src/main/native/exception_jni.h
new file mode 100644
index 0000000000..3ffff052df
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/native/exception_jni.h
@@ -0,0 +1,50 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_EXCEPTION_JNI_H_
+#define TENSORFLOW_CONTRIB_LITE_JAVA_EXCEPTION_JNI_H_
+
+#include <jni.h>
+#include "tensorflow/contrib/lite/error_reporter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+extern const char kIllegalArgumentException[];
+extern const char kIllegalStateException[];
+extern const char kNullPointerException[];
+extern const char kIndexOutOfBoundsException[];
+extern const char kUnsupportedOperationException[];
+
+void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...);
+
+class BufferErrorReporter : public tflite::ErrorReporter {
+ public:
+ BufferErrorReporter(JNIEnv* env, int limit);
+ virtual ~BufferErrorReporter();
+ int Report(const char* format, va_list args) override;
+ const char* CachedErrorMessage();
+
+ private:
+ char* buffer_;
+ int start_idx_ = 0;
+ int end_idx_ = 0;
+};
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+#endif // TENSORFLOW_CONTRIB_LITE_JAVA_EXCEPTION_JNI_H_
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
new file mode 100644
index 0000000000..bc6462eb54
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -0,0 +1,446 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h"
+
+namespace {
+
+const int kByteBufferValue = 999;
+const int kBufferSize = 256;
+
+tflite::Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) {
+ if (handle == 0) {
+ throwException(env, kIllegalArgumentException,
+ "Invalid handle to Interpreter.");
+ return nullptr;
+ }
+ return reinterpret_cast<tflite::Interpreter*>(handle);
+}
+
+tflite::FlatBufferModel* convertLongToModel(JNIEnv* env, jlong handle) {
+ if (handle == 0) {
+ throwException(env, kIllegalArgumentException, "Invalid handle to model.");
+ return nullptr;
+ }
+ return reinterpret_cast<tflite::FlatBufferModel*>(handle);
+}
+
+BufferErrorReporter* convertLongToErrorReporter(JNIEnv* env, jlong handle) {
+ if (handle == 0) {
+ throwException(env, kIllegalArgumentException,
+ "Invalid handle to ErrorReporter.");
+ return nullptr;
+ }
+ return reinterpret_cast<BufferErrorReporter*>(handle);
+}
+
+std::vector<int> convertJIntArrayToVector(JNIEnv* env, jintArray inputs) {
+ int size = static_cast<int>(env->GetArrayLength(inputs));
+ std::vector<int> outputs(size, 0);
+ jint* ptr = env->GetIntArrayElements(inputs, nullptr);
+ if (ptr == nullptr) {
+ throwException(env, kIllegalArgumentException,
+ "Empty dimensions of input array.");
+ return {};
+ }
+ for (int i = 0; i < size; ++i) {
+ outputs[i] = ptr[i];
+ }
+ env->ReleaseIntArrayElements(inputs, ptr, JNI_ABORT);
+ return outputs;
+}
+
+bool isByteBuffer(jint data_type) { return data_type == kByteBufferValue; }
+
+TfLiteType resolveDataType(jint data_type) {
+ switch (data_type) {
+ case 1:
+ return kTfLiteFloat32;
+ case 2:
+ return kTfLiteInt32;
+ case 3:
+ return kTfLiteUInt8;
+ case 4:
+ return kTfLiteInt64;
+ default:
+ return kTfLiteNoType;
+ }
+}
+
+void printDims(char* buffer, int max_size, int* dims, int num_dims) {
+ if (max_size <= 0) return;
+ buffer[0] = '?';
+ int size = 1;
+ for (int i = 1; i < num_dims; ++i) {
+ if (max_size > size) {
+ int written_size =
+ snprintf(buffer + size, max_size - size, ",%d", dims[i]);
+ if (written_size < 0) return;
+ size += written_size;
+ }
+ }
+}
+
+TfLiteStatus checkInputs(JNIEnv* env, tflite::Interpreter* interpreter,
+ const int input_size, jintArray data_types,
+ jintArray nums_of_bytes, jobjectArray values,
+ jobjectArray sizes) {
+ if (input_size != interpreter->inputs().size()) {
+ throwException(env, kIllegalArgumentException,
+ "Expected num of inputs is %d but got %d",
+ interpreter->inputs().size(), input_size);
+ return kTfLiteError;
+ }
+ if (input_size != env->GetArrayLength(data_types) ||
+ input_size != env->GetArrayLength(nums_of_bytes) ||
+ input_size != env->GetArrayLength(values)) {
+ throwException(env, kIllegalArgumentException,
+ "Arrays in arguments should be of the same length, but got "
+ "%d sizes, %d data_types, %d nums_of_bytes, and %d values",
+ input_size, env->GetArrayLength(data_types),
+ env->GetArrayLength(nums_of_bytes),
+ env->GetArrayLength(values));
+ return kTfLiteError;
+ }
+ for (int i = 0; i < input_size; ++i) {
+ int input_idx = interpreter->inputs()[i];
+ TfLiteTensor* target = interpreter->tensor(input_idx);
+ jintArray dims =
+ static_cast<jintArray>(env->GetObjectArrayElement(sizes, i));
+ int num_dims = static_cast<int>(env->GetArrayLength(dims));
+ if (target->dims->size != num_dims) {
+ throwException(env, kIllegalArgumentException,
+ "%d-th input should have %d dimensions, but found %d "
+ "dimensions",
+ i, target->dims->size, num_dims);
+ return kTfLiteError;
+ }
+ jint* ptr = env->GetIntArrayElements(dims, nullptr);
+ for (int j = 1; j < num_dims; ++j) {
+ if (target->dims->data[j] != ptr[j]) {
+ std::unique_ptr<char[]> expected_dims(new char[kBufferSize]);
+ std::unique_ptr<char[]> obtained_dims(new char[kBufferSize]);
+ printDims(expected_dims.get(), kBufferSize, target->dims->data,
+ num_dims);
+ printDims(obtained_dims.get(), kBufferSize, ptr, num_dims);
+ throwException(env, kIllegalArgumentException,
+ "%d-th input dimension should be [%s], but found [%s]",
+ i, expected_dims.get(), obtained_dims.get());
+ env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT);
+ return kTfLiteError;
+ }
+ }
+ env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT);
+ env->DeleteLocalRef(dims);
+ if (env->ExceptionCheck()) return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus resizeInputs(JNIEnv* env, tflite::Interpreter* interpreter,
+ int input_size, jobjectArray sizes) {
+ for (int i = 0; i < input_size; ++i) {
+ int input_idx = interpreter->inputs()[i];
+ jintArray dims =
+ static_cast<jintArray>(env->GetObjectArrayElement(sizes, i));
+ TfLiteStatus status = interpreter->ResizeInputTensor(
+ input_idx, convertJIntArrayToVector(env, dims));
+ if (status != kTfLiteOk) {
+ return status;
+ }
+ env->DeleteLocalRef(dims);
+ if (env->ExceptionCheck()) return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus setInputs(JNIEnv* env, tflite::Interpreter* interpreter,
+ int input_size, jintArray data_types,
+ jintArray nums_of_bytes, jobjectArray values) {
+ jint* data_type = env->GetIntArrayElements(data_types, nullptr);
+ jint* num_bytes = env->GetIntArrayElements(nums_of_bytes, nullptr);
+ for (int i = 0; i < input_size; ++i) {
+ int input_idx = interpreter->inputs()[i];
+ TfLiteTensor* target = interpreter->tensor(input_idx);
+ jobject value = env->GetObjectArrayElement(values, i);
+ bool is_byte_buffer = isByteBuffer(data_type[i]);
+ if (is_byte_buffer) {
+ writeByteBuffer(env, value, &(target->data.raw),
+ static_cast<int>(num_bytes[i]));
+ } else {
+ TfLiteType type = resolveDataType(data_type[i]);
+ if (type != target->type) {
+ throwException(env, kIllegalArgumentException,
+ "DataType (%d) of input data does not match with the "
+ "DataType (%d) of model inputs.",
+ type, target->type);
+ return kTfLiteError;
+ }
+ writeMultiDimensionalArray(env, value, target->type, target->dims->size,
+ &(target->data.raw),
+ static_cast<int>(num_bytes[i]));
+ }
+ env->DeleteLocalRef(value);
+ if (env->ExceptionCheck()) return kTfLiteError;
+ }
+ env->ReleaseIntArrayElements(data_types, data_type, JNI_ABORT);
+ env->ReleaseIntArrayElements(nums_of_bytes, num_bytes, JNI_ABORT);
+ return kTfLiteOk;
+}
+
+} // namespace
+
+JNIEXPORT jobjectArray JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return nullptr;
+ jclass string_class = env->FindClass("java/lang/String");
+ if (string_class == nullptr) {
+ throwException(env, kUnsupportedOperationException,
+ "Can not find java/lang/String class to get input names.");
+ return nullptr;
+ }
+ size_t size = interpreter->inputs().size();
+ jobjectArray names = static_cast<jobjectArray>(
+ env->NewObjectArray(size, string_class, env->NewStringUTF("")));
+ for (int i = 0; i < size; ++i) {
+ env->SetObjectArrayElement(names, i,
+ env->NewStringUTF(interpreter->GetInputName(i)));
+ }
+ return names;
+}
+
+JNIEXPORT jobjectArray JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return nullptr;
+ jclass string_class = env->FindClass("java/lang/String");
+ if (string_class == nullptr) {
+ throwException(env, kUnsupportedOperationException,
+ "Can not find java/lang/String class to get output names.");
+ return nullptr;
+ }
+ size_t size = interpreter->outputs().size();
+ jobjectArray names = static_cast<jobjectArray>(
+ env->NewObjectArray(size, string_class, env->NewStringUTF("")));
+ for (int i = 0; i < size; ++i) {
+ env->SetObjectArrayElement(
+ names, i, env->NewStringUTF(interpreter->GetOutputName(i)));
+ }
+ return names;
+}
+
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jboolean state) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return;
+ interpreter->UseNNAPI(static_cast<bool>(state));
+}
+
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter(
+ JNIEnv* env, jclass clazz, jint size) {
+ BufferErrorReporter* error_reporter =
+ new BufferErrorReporter(env, static_cast<int>(size));
+ return reinterpret_cast<jlong>(error_reporter);
+}
+
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel(
+ JNIEnv* env, jclass clazz, jstring model_file, jlong error_handle) {
+ BufferErrorReporter* error_reporter =
+ convertLongToErrorReporter(env, error_handle);
+ if (error_reporter == nullptr) return 0;
+ const char* path = env->GetStringUTFChars(model_file, nullptr);
+ auto model = tflite::FlatBufferModel::BuildFromFile(path, error_reporter);
+ if (!model) {
+ throwException(env, kIllegalArgumentException,
+ "Contents of %s does not encode a valid TensorFlowLite "
+ "model: %s",
+ path, error_reporter->CachedErrorMessage());
+ env->ReleaseStringUTFChars(model_file, path);
+ return 0;
+ }
+ env->ReleaseStringUTFChars(model_file, path);
+ return reinterpret_cast<jlong>(model.release());
+}
+
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer(
+ JNIEnv* env, jclass /*clazz*/, jobject model_buffer, jlong error_handle) {
+ BufferErrorReporter* error_reporter =
+ convertLongToErrorReporter(env, error_handle);
+ if (error_reporter == nullptr) return 0;
+ const char* buf =
+ static_cast<char*>(env->GetDirectBufferAddress(model_buffer));
+ jlong capacity = env->GetDirectBufferCapacity(model_buffer);
+ auto model = tflite::FlatBufferModel::BuildFromBuffer(
+ buf, static_cast<size_t>(capacity), error_reporter);
+ if (!model) {
+ throwException(env, kIllegalArgumentException,
+ "MappedByteBuffer does not encode a valid TensorFlowLite "
+ "model: %s",
+ error_reporter->CachedErrorMessage());
+ return 0;
+ }
+ return reinterpret_cast<jlong>(model.release());
+}
+
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
+ JNIEnv* env, jclass clazz, jlong model_handle) {
+ tflite::FlatBufferModel* model = convertLongToModel(env, model_handle);
+ if (model == nullptr) return 0;
+ auto resolver = ::tflite::CreateOpResolver();
+ std::unique_ptr<tflite::Interpreter> interpreter;
+ tflite::InterpreterBuilder(*model, *(resolver.get()))(&interpreter);
+ return reinterpret_cast<jlong>(interpreter.release());
+}
+
+// Sets inputs, runs inference, and returns outputs as long handles.
+JNIEXPORT jlongArray JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
+ JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
+ jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes,
+ jobjectArray values) {
+ tflite::Interpreter* interpreter =
+ convertLongToInterpreter(env, interpreter_handle);
+ if (interpreter == nullptr) return nullptr;
+ BufferErrorReporter* error_reporter =
+ convertLongToErrorReporter(env, error_handle);
+ if (error_reporter == nullptr) return nullptr;
+ const int input_size = env->GetArrayLength(sizes);
+ // validates inputs
+ TfLiteStatus status = checkInputs(env, interpreter, input_size, data_types,
+ nums_of_bytes, values, sizes);
+ if (status != kTfLiteOk) return nullptr;
+ // resizes inputs
+ status = resizeInputs(env, interpreter, input_size, sizes);
+ if (status != kTfLiteOk) {
+ throwException(env, kNullPointerException, "Can not resize the input: %s",
+ error_reporter->CachedErrorMessage());
+ return nullptr;
+ }
+ // allocates memory
+ status = interpreter->AllocateTensors();
+ if (status != kTfLiteOk) {
+ throwException(env, kNullPointerException,
+ "Can not allocate memory for the given inputs: %s",
+ error_reporter->CachedErrorMessage());
+ return nullptr;
+ }
+ // sets inputs
+ status = setInputs(env, interpreter, input_size, data_types, nums_of_bytes,
+ values);
+ if (status != kTfLiteOk) return nullptr;
+ // runs inference
+ if (interpreter->Invoke() != kTfLiteOk) {
+ throwException(env, kIllegalArgumentException,
+ "Failed to run on the given Interpreter: %s",
+ error_reporter->CachedErrorMessage());
+ return nullptr;
+ }
+ // returns outputs
+ const std::vector<int>& results = interpreter->outputs();
+ if (results.empty()) {
+ throwException(env, kIllegalArgumentException,
+ "The Interpreter does not have any outputs.");
+ return nullptr;
+ }
+ jlongArray outputs = env->NewLongArray(results.size());
+ size_t size = results.size();
+ for (int i = 0; i < size; ++i) {
+ TfLiteTensor* source = interpreter->tensor(results[i]);
+ jlong output = reinterpret_cast<jlong>(source);
+ env->SetLongArrayRegion(outputs, i, 1, &output);
+ }
+ return outputs;
+}
+
+JNIEXPORT jintArray JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims(
+ JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return nullptr;
+ const int idx = static_cast<int>(input_idx);
+ if (input_idx >= interpreter->inputs().size()) {
+ throwException(env, kIllegalArgumentException,
+ "Out of range: Failed to get %d-th input out of %d inputs",
+ input_idx, interpreter->inputs().size());
+ return nullptr;
+ }
+ TfLiteTensor* target = interpreter->tensor(interpreter->inputs()[idx]);
+ int size = target->dims->size;
+ int expected_num_bytes = elementByteSize(target->type);
+ for (int i = 0; i < size; ++i) {
+ expected_num_bytes *= target->dims->data[i];
+ }
+ if (num_bytes != expected_num_bytes) {
+ throwException(env, kIllegalArgumentException,
+ "Failed to get input dimensions. %d-th input should have"
+ " %d bytes, but found %d bytes.",
+ idx, expected_num_bytes, num_bytes);
+ return nullptr;
+ }
+ jintArray outputs = env->NewIntArray(size);
+ env->SetIntArrayRegion(outputs, 0, size, &(target->dims->data[0]));
+ return outputs;
+}
+
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput(
+ JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
+ jint input_idx, jintArray dims) {
+ BufferErrorReporter* error_reporter =
+ convertLongToErrorReporter(env, error_handle);
+ if (error_reporter == nullptr) return;
+ tflite::Interpreter* interpreter =
+ convertLongToInterpreter(env, interpreter_handle);
+ if (interpreter == nullptr) return;
+ const int idx = static_cast<int>(input_idx);
+ if (idx < 0 || idx >= interpreter->inputs().size()) {
+ throwException(env, kIllegalArgumentException,
+ "Can not resize %d-th input for a model having %d inputs.",
+ idx, interpreter->inputs().size());
+ }
+ TfLiteStatus status = interpreter->ResizeInputTensor(
+ interpreter->inputs()[idx], convertJIntArrayToVector(env, dims));
+ if (status != kTfLiteOk) {
+ throwException(env, kIllegalArgumentException,
+ "Failed to resize %d-th input: %s", idx,
+ error_reporter->CachedErrorMessage());
+ }
+}
+
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_delete(
+ JNIEnv* env, jclass clazz, jlong error_handle, jlong model_handle,
+ jlong interpreter_handle) {
+ if (interpreter_handle != 0) {
+ delete convertLongToInterpreter(env, interpreter_handle);
+ }
+ if (model_handle != 0) {
+ delete convertLongToModel(env, model_handle);
+ }
+ if (error_handle != 0) {
+ delete convertLongToErrorReporter(env, error_handle);
+ }
+}
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
new file mode 100644
index 0000000000..430886b7cc
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
@@ -0,0 +1,151 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_NATIVEINTERPRETERWRAPPER_JNI_H_
+#define TENSORFLOW_CONTRIB_LITE_JAVA_NATIVEINTERPRETERWRAPPER_JNI_H_
+
+#include <jni.h>
+#include <stdio.h>
+#include <vector>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h"
+#include "tensorflow/contrib/lite/java/src/main/native/tensor_jni.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+// This is to be provided at link-time by a library.
+extern std::unique_ptr<OpResolver> CreateOpResolver();
+} // namespace tflite
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
+ * Signature: (J)[Ljava/lang/Object;
+ */
+JNIEXPORT jobjectArray JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env,
+ jclass clazz,
+ jlong handle);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
+ * Signature: (J)[Ljava/lang/Object;
+ */
+JNIEXPORT jobjectArray JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env,
+ jclass clazz,
+ jlong handle);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
+ * Signature: (JZ)
+ */
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jboolean state);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
+ * Signature: (I)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter(
+ JNIEnv* env, jclass clazz, jint size);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
+ * Signature: (Ljava/lang/String;J)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel(
+ JNIEnv* env, jclass clazz, jstring model_file, jlong error_handle);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
+ * Signature: (Ljava/lang/Object;J)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer(
+ JNIEnv* env, jclass clazz, jobject model_buffer, jlong error_handle);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
+ * Signature: (J)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
+ JNIEnv* env, jclass clazz, jlong model_handle);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
+ * Signature: (JJ[Ljava/lang/Object;[I[I[Ljava/lang/Object;)[J
+ */
+JNIEXPORT jlongArray JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
+ JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
+ jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes,
+ jobjectArray values);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
+ * Signature: (JII)[I
+ *
+ * It gets input dimensions if num_bytes matches number of bytes required by
+ * the input, else returns null and throws IllegalArgumentException.
+ */
+JNIEXPORT jintArray JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims(
+ JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
+ * Signature: (JJI[I)
+ *
+ * It resizes dimensions of a input.
+ */
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput(
+ JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
+ jint input_idx, jintArray dims);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
+ * Signature: (JJJ)
+ */
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_delete(
+ JNIEnv* env, jclass clazz, jlong error_handle, jlong model_handle,
+ jlong interpreter_handle);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+#endif // TENSORFLOW_CONTRIB_LITE_JAVA_NATIVEINTERPRETERWRAPPER_JNI_H_
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
new file mode 100644
index 0000000000..65126e78a3
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
@@ -0,0 +1,242 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/java/src/main/native/tensor_jni.h"
+#include <cstring>
+#include <memory>
+#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h"
+
+namespace {
+
+TfLiteTensor* convertLongToTensor(JNIEnv* env, jlong handle) {
+ if (handle == 0) {
+ throwException(env, kIllegalArgumentException,
+ "Invalid handle to TfLiteTensor.");
+ return nullptr;
+ }
+ return reinterpret_cast<TfLiteTensor*>(handle);
+}
+
+size_t writeOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type,
+ void* dst, size_t dst_size) {
+ jarray array = static_cast<jarray>(object);
+ const int num_elements = env->GetArrayLength(array);
+ size_t to_copy = num_elements * elementByteSize(type);
+ if (to_copy > dst_size) {
+ throwException(env, kIllegalStateException,
+ "cannot write Java array of %d bytes to Tensor of %d bytes",
+ to_copy, dst_size);
+ return 0;
+ }
+ switch (type) {
+ case kTfLiteFloat32: {
+ jfloatArray a = static_cast<jfloatArray>(array);
+ jfloat* values = env->GetFloatArrayElements(a, nullptr);
+ memcpy(dst, values, to_copy);
+ env->ReleaseFloatArrayElements(a, values, JNI_ABORT);
+ return to_copy;
+ }
+ case kTfLiteInt32: {
+ jintArray a = static_cast<jintArray>(array);
+ jint* values = env->GetIntArrayElements(a, nullptr);
+ memcpy(dst, values, to_copy);
+ env->ReleaseIntArrayElements(a, values, JNI_ABORT);
+ return to_copy;
+ }
+ case kTfLiteInt64: {
+ jlongArray a = static_cast<jlongArray>(array);
+ jlong* values = env->GetLongArrayElements(a, nullptr);
+ memcpy(dst, values, to_copy);
+ env->ReleaseLongArrayElements(a, values, JNI_ABORT);
+ return to_copy;
+ }
+ case kTfLiteUInt8: {
+ jbyteArray a = static_cast<jbyteArray>(array);
+ jbyte* values = env->GetByteArrayElements(a, nullptr);
+ memcpy(dst, values, to_copy);
+ env->ReleaseByteArrayElements(a, values, JNI_ABORT);
+ return to_copy;
+ }
+ default: {
+ throwException(env, kUnsupportedOperationException,
+ "TensorFlowLite currently supports float (32 bits), "
+ "int (32 bits), byte (8 bits), and long (64 bits), "
+ "support for other types (DataType %d in this case) will "
+ "be added in the future",
+ kTfLiteFloat32, type);
+ return 0;
+ }
+ }
+}
+
+size_t readOneDimensionalArray(JNIEnv* env, TfLiteType data_type,
+ const void* src, size_t src_size, jarray dst) {
+ const int len = env->GetArrayLength(dst);
+ const size_t size = len * elementByteSize(data_type);
+ if (size > src_size) {
+ throwException(
+ env, kIllegalStateException,
+ "cannot fill a Java array of %d bytes with a Tensor of %d bytes", size,
+ src_size);
+ return 0;
+ }
+ switch (data_type) {
+ case kTfLiteFloat32: {
+ jfloatArray float_array = static_cast<jfloatArray>(dst);
+ env->SetFloatArrayRegion(float_array, 0, len,
+ static_cast<const jfloat*>(src));
+ return size;
+ }
+ case kTfLiteInt32: {
+ jintArray int_array = static_cast<jintArray>(dst);
+ env->SetIntArrayRegion(int_array, 0, len, static_cast<const jint*>(src));
+ return size;
+ }
+ case kTfLiteInt64: {
+ jlongArray long_array = static_cast<jlongArray>(dst);
+ env->SetLongArrayRegion(long_array, 0, len,
+ static_cast<const jlong*>(src));
+ return size;
+ }
+ case kTfLiteUInt8: {
+ jbyteArray byte_array = static_cast<jbyteArray>(dst);
+ env->SetByteArrayRegion(byte_array, 0, len,
+ static_cast<const jbyte*>(src));
+ return size;
+ }
+ default: {
+ throwException(env, kIllegalStateException, "invalid DataType(%d)",
+ data_type);
+ }
+ }
+ return 0;
+}
+
+size_t readMultiDimensionalArray(JNIEnv* env, TfLiteType data_type, char* src,
+ size_t src_size, int dims_left, jarray dst) {
+ if (dims_left == 1) {
+ return readOneDimensionalArray(env, data_type, src, src_size, dst);
+ } else {
+ jobjectArray ndarray = static_cast<jobjectArray>(dst);
+ int len = env->GetArrayLength(ndarray);
+ size_t size = 0;
+ for (int i = 0; i < len; ++i) {
+ jarray row = static_cast<jarray>(env->GetObjectArrayElement(ndarray, i));
+ size += readMultiDimensionalArray(env, data_type, src + size,
+ src_size - size, dims_left - 1, row);
+ env->DeleteLocalRef(row);
+ if (env->ExceptionCheck()) return size;
+ }
+ return size;
+ }
+}
+
+} // namespace
+
+size_t elementByteSize(TfLiteType data_type) {
+ // The code in this file makes the assumption that the
+ // TensorFlow TF_DataTypes and the Java primitive types
+ // have the same byte sizes. Validate that:
+ switch (data_type) {
+ case kTfLiteFloat32:
+ static_assert(sizeof(jfloat) == 4,
+ "Java float not compatible with kTfLiteFloat");
+ return 4;
+ case kTfLiteInt32:
+ static_assert(sizeof(jint) == 4,
+ "Java int not compatible with kTfLiteInt");
+ return 4;
+ case kTfLiteUInt8:
+ static_assert(sizeof(jbyte) == 1,
+ "Java byte not compatible with kTfLiteUInt8");
+ return 1;
+ case kTfLiteInt64:
+ static_assert(sizeof(jlong) == 8,
+ "Java long not compatible with kTfLiteInt64");
+ return 8;
+ default:
+ return 0;
+ }
+}
+
+size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size) {
+ char* buf = static_cast<char*>(env->GetDirectBufferAddress(object));
+ if (!buf) {
+ throwException(env, kIllegalArgumentException,
+ "Input ByteBuffer is not a direct buffer");
+ return 0;
+ }
+ *dst = buf;
+ return dst_size;
+}
+
+size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
+ int dims_left, char** dst, int dst_size) {
+ if (dims_left <= 1) {
+ return writeOneDimensionalArray(env, src, type, *dst, dst_size);
+ } else {
+ jobjectArray ndarray = static_cast<jobjectArray>(src);
+ int len = env->GetArrayLength(ndarray);
+ size_t sz = 0;
+ for (int i = 0; i < len; ++i) {
+ jobject row = env->GetObjectArrayElement(ndarray, i);
+ char* next_dst = *dst + sz;
+ sz += writeMultiDimensionalArray(env, row, type, dims_left - 1, &next_dst,
+ dst_size - sz);
+ env->DeleteLocalRef(row);
+ if (env->ExceptionCheck()) return sz;
+ }
+ return sz;
+ }
+}
+
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jobject value) {
+ TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ if (tensor == nullptr) return;
+ int num_dims = tensor->dims->size;
+ if (num_dims == 0) {
+ throwException(env, kIllegalArgumentException,
+ "copyTo() is not meant for scalar Tensors.");
+ return;
+ }
+ readMultiDimensionalArray(env, tensor->type, tensor->data.raw, tensor->bytes,
+ num_dims, static_cast<jarray>(value));
+}
+
+JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ if (tensor == nullptr) return 0;
+ return static_cast<jint>(tensor->type);
+}
+
+JNIEXPORT jintArray JNICALL
+Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, jclass clazz, jlong handle) {
+ TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ if (tensor == nullptr) return nullptr;
+ int num_dims = tensor->dims->size;
+ jintArray result = env->NewIntArray(num_dims);
+ jint* dims = env->GetIntArrayElements(result, nullptr);
+ for (int i = 0; i < num_dims; ++i) {
+ dims[i] = static_cast<jint>(tensor->dims->data[i]);
+ }
+ env->ReleaseIntArrayElements(result, dims, 0);
+ return result;
+}
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
new file mode 100644
index 0000000000..3a4910dcc3
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
@@ -0,0 +1,74 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_TENSOR_JNI_H_
+#define TENSORFLOW_CONTRIB_LITE_JAVA_TENSOR_JNI_H_
+
+#include <jni.h>
+#include "tensorflow/contrib/lite/context.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+/*
+ * Class: org_tensorflow_lite_TfLiteTensor
+ * Method:
+ * Signature: (J)I
+ */
+JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env,
+ jclass clazz,
+ jlong handle);
+
+/*
+ * Class: org_tensorflow_lite_TfLiteTensor
+ * Method:
+ * Signature: (J)[I
+ */
+JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env,
+ jclass clazz,
+ jlong handle);
+
+/*
+ * Class: org_tensorflow_lite_TfLiteTensor
+ * Method:
+ * Signature: (JLjava/lang/Object;)
+ */
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jobject value);
+
+/*
+ * Finds the size of each data type.
+ */
+size_t elementByteSize(TfLiteType data_type);
+
+/*
+ * Writes data of a ByteBuffer into dest.
+ */
+size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size);
+
+/*
+ * Writes a multi-dimensional array into dest.
+ */
+size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
+ int dims_left, char** dst, int dst_size);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+#endif // TENSORFLOW_CONTRIB_LITE_JAVA_TENSOR_JNI_H_
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.cc
new file mode 100644
index 0000000000..2e7f2f5692
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.cc
@@ -0,0 +1,26 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stdio.h>
+
+#include "tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h"
+#include "tensorflow/contrib/lite/version.h"
+
+JNIEXPORT jstring JNICALL
+Java_org_tensorflow_lite_TensorFlowLite_version(JNIEnv* env, jclass /*clazz*/) {
+ char buf[64];
+ snprintf(buf, sizeof(buf), "%d", TFLITE_SCHEMA_VERSION);
+ return env->NewStringUTF(buf);
+}
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h
new file mode 100644
index 0000000000..65f8341149
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h
@@ -0,0 +1,36 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_TENSORFLOW_LITE_JNI_H_
+#define TENSORFLOW_CONTRIB_LITE_JAVA_TENSORFLOW_LITE_JNI_H_
+
+#include <jni.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+/*
+ * Class: org_tensorflow_lite_TensorFlowLite
+ * Method: version
+ * Signature: ()Ljava/lang/String;
+ */
+JNIEXPORT jstring JNICALL
+Java_org_tensorflow_lite_TensorFlowLite_version(JNIEnv*, jclass);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+#endif // TENSORFLOW_CONTRIB_LITE_JAVA_TENSORFLOW_LITE_JNI_H_
diff --git a/tensorflow/contrib/lite/java/src/main/native/version_script.lds b/tensorflow/contrib/lite/java/src/main/native/version_script.lds
new file mode 100644
index 0000000000..38c93dda73
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/native/version_script.lds
@@ -0,0 +1,11 @@
+VERS_1.0 {
+ # Export JNI symbols.
+ global:
+ Java_*;
+ JNI_OnLoad;
+ JNI_OnUnload;
+
+ # Hide everything else.
+ local:
+ *;
+};
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java
new file mode 100644
index 0000000000..cebc944200
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java
@@ -0,0 +1,34 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.lite;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link org.tensorflow.lite.DataType}. */
+@RunWith(JUnit4.class)
+public final class DataTypeTest {
+
+ @Test
+ public void testElemByteSize() {
+ assertThat(DataType.FLOAT32.elemByteSize()).isEqualTo(4);
+ assertThat(DataType.INT32.elemByteSize()).isEqualTo(4);
+ assertThat(DataType.UINT8.elemByteSize()).isEqualTo(1);
+ assertThat(DataType.INT64.elemByteSize()).isEqualTo(8);
+ }
+}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
new file mode 100644
index 0000000000..424b3de6c9
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -0,0 +1,221 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.fail;
+
+import java.io.File;
+import java.nio.MappedByteBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardOpenOption;
+import java.util.EnumSet;
+import java.util.HashMap;
+import java.util.Map;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link org.tensorflow.lite.Interpreter}. */
+@RunWith(JUnit4.class)
+public final class InterpreterTest {
+
+ private static final File MODEL_FILE =
+ new File("tensorflow/contrib/lite/java/src/testdata/add.bin");
+
+ private static final File MOBILENET_MODEL_FILE =
+ new File("tensorflow/contrib/lite/java/src/testdata/mobilenet.tflite.bin");
+
+ @Test
+ public void testInterpreter() throws Exception {
+ Interpreter interpreter = new Interpreter(MODEL_FILE);
+ assertThat(interpreter).isNotNull();
+ interpreter.close();
+ }
+
+ @Test
+ public void testRunWithMappedByteBufferModel() throws Exception {
+ Path path = MODEL_FILE.toPath();
+ FileChannel fileChannel =
+ (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ));
+ MappedByteBuffer mappedByteBuffer =
+ fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size());
+ Interpreter interpreter = new Interpreter(mappedByteBuffer);
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ interpreter.run(fourD, parsedOutputs);
+ float[] outputOneD = parsedOutputs[0][0][0];
+ float[] expected = {3.69f, 19.62f, 23.43f};
+ assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+ interpreter.close();
+ fileChannel.close();
+ }
+
+ @Test
+ public void testRun() {
+ Interpreter interpreter = new Interpreter(MODEL_FILE);
+ Float[] oneD = {1.23f, 6.54f, 7.81f};
+ Float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ Float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ Float[][][][] fourD = {threeD, threeD};
+ Float[][][][] parsedOutputs = new Float[2][8][8][3];
+ try {
+ interpreter.run(fourD, parsedOutputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("cannot resolve DataType of [[[[Ljava.lang.Float;");
+ }
+ interpreter.close();
+ }
+
+ @Test
+ public void testRunWithBoxedInputs() {
+ Interpreter interpreter = new Interpreter(MODEL_FILE);
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ interpreter.run(fourD, parsedOutputs);
+ float[] outputOneD = parsedOutputs[0][0][0];
+ float[] expected = {3.69f, 19.62f, 23.43f};
+ assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+ interpreter.close();
+ }
+
+ @Test
+ public void testRunForMultipleInputsOutputs() {
+ Interpreter interpreter = new Interpreter(MODEL_FILE);
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ Object[] inputs = {fourD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ interpreter.runForMultipleInputsOutputs(inputs, outputs);
+ float[] outputOneD = parsedOutputs[0][0][0];
+ float[] expected = {3.69f, 19.62f, 23.43f};
+ assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+ interpreter.close();
+ }
+
+ @Test
+ public void testMobilenetRun() {
+ // Create a gray image.
+ float[][][][] img = new float[1][224][224][3];
+ for (int i = 0; i < 224; ++i) {
+ for (int j = 0; j < 224; ++j) {
+ img[0][i][j][0] = 0.5f;
+ img[0][i][j][1] = 0.5f;
+ img[0][i][j][2] = 0.5f;
+ }
+ }
+
+ // Allocate memory to receive the output values.
+ float[][] labels = new float[1][1001];
+
+ Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE);
+ interpreter.run(img, labels);
+ interpreter.close();
+
+ assertThat(labels[0])
+ .usingExactEquality()
+ .containsNoneOf(new float[] {Float.NaN, Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY});
+ }
+
+ @Test
+ public void testRunWithWrongInputType() {
+ Interpreter interpreter = new Interpreter(MODEL_FILE);
+ int[] oneD = {4, 3, 9};
+ int[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ int[][][][] fourD = {threeD, threeD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ try {
+ interpreter.run(fourD, parsedOutputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains(
+ "DataType (2) of input data does not match with the DataType (1) of model inputs.");
+ }
+ interpreter.close();
+ }
+
+ @Test
+ public void testRunWithWrongOutputType() {
+ Interpreter interpreter = new Interpreter(MODEL_FILE);
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ int[][][][] parsedOutputs = new int[2][8][8][3];
+ try {
+ interpreter.run(fourD, parsedOutputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains(
+ "Cannot convert an TensorFlowLite tensor with type "
+ + "FLOAT32 to a Java object of type [[[[I (which is compatible with the"
+ + " TensorFlowLite type INT32)");
+ }
+ interpreter.close();
+ }
+
+ @Test
+ public void testGetInputIndex() {
+ Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE);
+ try {
+ interpreter.getInputIndex("WrongInputName");
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains(
+ "WrongInputName is not a valid name for any input. The indexes of the inputs"
+ + " are {input=0}");
+ }
+ int index = interpreter.getInputIndex("input");
+ assertThat(index).isEqualTo(0);
+ }
+
+ @Test
+ public void testGetOutputIndex() {
+ Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE);
+ try {
+ interpreter.getOutputIndex("WrongOutputName");
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains(
+ "WrongOutputName is not a valid name for any output. The indexes of the outputs"
+ + " are {MobilenetV1/Predictions/Softmax=0}");
+ }
+ int index = interpreter.getOutputIndex("MobilenetV1/Predictions/Softmax");
+ assertThat(index).isEqualTo(0);
+ }
+}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
new file mode 100644
index 0000000000..9a6894f49c
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
@@ -0,0 +1,406 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.fail;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link org.tensorflow.lite.NativeInterpreterWrapper}. */
+@RunWith(JUnit4.class)
+public final class NativeInterpreterWrapperTest {
+
+ private static final String FLOAT_MODEL_PATH =
+ "tensorflow/contrib/lite/java/src/testdata/add.bin";
+
+ private static final String INT_MODEL_PATH =
+ "tensorflow/contrib/lite/java/src/testdata/int32.bin";
+
+ private static final String LONG_MODEL_PATH =
+ "tensorflow/contrib/lite/java/src/testdata/int64.bin";
+
+ private static final String BYTE_MODEL_PATH =
+ "tensorflow/contrib/lite/java/src/testdata/uint8.bin";
+
+ private static final String INVALID_MODEL_PATH =
+ "tensorflow/contrib/lite/java/src/testdata/invalid_model.bin";
+
+ @Test
+ public void testConstructor() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+ assertThat(wrapper).isNotNull();
+ wrapper.close();
+ }
+
+ @Test
+ public void testConstructorWithInvalidModel() {
+ try {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains("Model provided has model identifier ' is ', should be 'TFL3'");
+ }
+ }
+
+ @Test
+ public void testRunWithFloat() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+ float[] oneD = {1.23f, -6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ Object[] inputs = {fourD};
+ Tensor[] outputs = wrapper.run(inputs);
+ assertThat(outputs.length).isEqualTo(1);
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ outputs[0].copyTo(parsedOutputs);
+ float[] outputOneD = parsedOutputs[0][0][0];
+ float[] expected = {3.69f, -19.62f, 23.43f};
+ assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+ wrapper.close();
+ }
+
+ @Test
+ public void testRunWithInt() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INT_MODEL_PATH);
+ int[] oneD = {3, 7, -4};
+ int[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ int[][][][] fourD = {threeD, threeD};
+ Object[] inputs = {fourD};
+ Tensor[] outputs = wrapper.run(inputs);
+ assertThat(outputs.length).isEqualTo(1);
+ int[][][][] parsedOutputs = new int[2][4][4][12];
+ outputs[0].copyTo(parsedOutputs);
+ int[] outputOneD = parsedOutputs[0][0][0];
+ int[] expected = {3, 7, -4, 3, 7, -4, 3, 7, -4, 3, 7, -4};
+ assertThat(outputOneD).isEqualTo(expected);
+ wrapper.close();
+ }
+
+ @Test
+ public void testRunWithLong() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(LONG_MODEL_PATH);
+ long[] oneD = {-892834092L, 923423L, 2123918239018L};
+ long[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ long[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ long[][][][] fourD = {threeD, threeD};
+ Object[] inputs = {fourD};
+ Tensor[] outputs = wrapper.run(inputs);
+ assertThat(outputs.length).isEqualTo(1);
+ long[][][][] parsedOutputs = new long[2][4][4][12];
+ outputs[0].copyTo(parsedOutputs);
+ long[] outputOneD = parsedOutputs[0][0][0];
+ long[] expected = {-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L,
+ -892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L};
+ assertThat(outputOneD).isEqualTo(expected);
+ wrapper.close();
+ }
+
+ @Test
+ public void testRunWithByte() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH);
+ byte[] oneD = {(byte) 0xe0, 0x4f, (byte) 0xd0};
+ byte[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ byte[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ byte[][][][] fourD = {threeD, threeD};
+ Object[] inputs = {fourD};
+ int[] inputDims = {2, 8, 8, 3};
+ wrapper.resizeInput(0, inputDims);
+ Tensor[] outputs = wrapper.run(inputs);
+ assertThat(outputs.length).isEqualTo(1);
+ byte[][][][] parsedOutputs = new byte[2][4][4][12];
+ outputs[0].copyTo(parsedOutputs);
+ byte[] outputOneD = parsedOutputs[0][0][0];
+ byte[] expected = {(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
+ (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0};
+ assertThat(outputOneD).isEqualTo(expected);
+ wrapper.close();
+ }
+
+ @Test
+ public void testRunWithByteBufferHavingBytes() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH);
+ ByteBuffer bbuf = ByteBuffer.allocateDirect(2 * 8 * 8 * 3);
+ bbuf.order(ByteOrder.nativeOrder());
+ bbuf.rewind();
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 8; ++j) {
+ for (int k = 0; k < 8; ++k) {
+ bbuf.put((byte) 0xe0);
+ bbuf.put((byte) 0x4f);
+ bbuf.put((byte) 0xd0);
+ }
+ }
+ }
+ Object[] inputs = {bbuf};
+ int[] inputDims = {2, 8, 8, 3};
+ wrapper.resizeInput(0, inputDims);
+ Tensor[] outputs = wrapper.run(inputs);
+ assertThat(outputs.length).isEqualTo(1);
+ byte[][][][] parsedOutputs = new byte[2][4][4][12];
+ outputs[0].copyTo(parsedOutputs);
+ byte[] outputOneD = parsedOutputs[0][0][0];
+ byte[] expected = {
+ (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
+ (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0
+ };
+ assertThat(outputOneD).isEqualTo(expected);
+ wrapper.close();
+ }
+
+ @Test
+ public void testRunWithByteBufferHavingFloats() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+ ByteBuffer bbuf = ByteBuffer.allocateDirect(4 * 8 * 8 * 3 * 4);
+ bbuf.order(ByteOrder.nativeOrder());
+ bbuf.rewind();
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 8; ++j) {
+ for (int k = 0; k < 8; ++k) {
+ bbuf.putFloat(1.23f);
+ bbuf.putFloat(-6.54f);
+ bbuf.putFloat(7.81f);
+ }
+ }
+ }
+ Object[] inputs = {bbuf};
+ try {
+ wrapper.run(inputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains(
+ "Failed to get input dimensions. 0-th input should have 768 bytes, but found 3072 bytes");
+ }
+ int[] inputDims = {4, 8, 8, 3};
+ wrapper.resizeInput(0, inputDims);
+ Tensor[] outputs = wrapper.run(inputs);
+ assertThat(outputs.length).isEqualTo(1);
+ float[][][][] parsedOutputs = new float[4][8][8][3];
+ outputs[0].copyTo(parsedOutputs);
+ float[] outputOneD = parsedOutputs[0][0][0];
+ float[] expected = {3.69f, -19.62f, 23.43f};
+ assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+ wrapper.close();
+ }
+
+ @Test
+ public void testRunWithByteBufferHavingWrongSize() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH);
+ ByteBuffer bbuf = ByteBuffer.allocateDirect(2 * 7 * 8 * 3);
+ bbuf.order(ByteOrder.nativeOrder());
+ Object[] inputs = {bbuf};
+ try {
+ wrapper.run(inputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains(
+ "Failed to get input dimensions. 0-th input should have 192 bytes, but found 336 bytes.");
+ }
+ wrapper.close();
+ }
+
+ @Test
+ public void testRunWithWrongInputType() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+ int[] oneD = {4, 3, 9};
+ int[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ int[][][][] fourD = {threeD, threeD};
+ Object[] inputs = {fourD};
+ try {
+ wrapper.run(inputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains(
+ "DataType (2) of input data does not match with the DataType (1) of model inputs.");
+ }
+ wrapper.close();
+ }
+
+ @Test
+ public void testRunAfterClose() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+ wrapper.close();
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ Object[] inputs = {fourD};
+ try {
+ wrapper.run(inputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("Invalid handle to Interpreter.");
+ }
+ }
+
+ @Test
+ public void testRunWithEmptyInputs() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+ try {
+ Object[] inputs = {};
+ wrapper.run(inputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains("Invalid inputs. Inputs should not be null or empty.");
+ }
+ wrapper.close();
+ }
+
+ @Test
+ public void testRunWithWrongInputSize() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ Object[] inputs = {fourD, fourD};
+ try {
+ wrapper.run(inputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("Expected num of inputs is 1 but got 2");
+ }
+ wrapper.close();
+ }
+
+ @Test
+ public void testRunWithWrongInputNumOfDims() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ Object[] inputs = {threeD};
+ try {
+ wrapper.run(inputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains("0-th input should have 4 dimensions, but found 3 dimensions");
+ }
+ wrapper.close();
+ }
+
+ @Test
+ public void testRunWithWrongInputDims() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ Object[] inputs = {fourD};
+ try {
+ wrapper.run(inputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]");
+ }
+ wrapper.close();
+ }
+
+ @Test
+ public void testNumElements() {
+ int[] shape = {2, 3, 4};
+ int num = NativeInterpreterWrapper.numElements(shape);
+ assertThat(num).isEqualTo(24);
+ shape = null;
+ num = NativeInterpreterWrapper.numElements(shape);
+ assertThat(num).isEqualTo(0);
+ }
+
+ @Test
+ public void testIsNonEmtpyArray() {
+ assertThat(NativeInterpreterWrapper.isNonEmptyArray(null)).isFalse();
+ assertThat(NativeInterpreterWrapper.isNonEmptyArray(3.2)).isFalse();
+ int[] emptyArray = {};
+ assertThat(NativeInterpreterWrapper.isNonEmptyArray(emptyArray)).isFalse();
+ int[] validArray = {9, 5, 2, 1};
+ assertThat(NativeInterpreterWrapper.isNonEmptyArray(validArray)).isTrue();
+ }
+
+ @Test
+ public void testDataTypeOf() {
+ float[] testEmtpyArray = {};
+ DataType dataType = NativeInterpreterWrapper.dataTypeOf(testEmtpyArray);
+ assertThat(dataType).isEqualTo(DataType.FLOAT32);
+ float[] testFloatArray = {0.783f, 0.251f};
+ dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray);
+ assertThat(dataType).isEqualTo(DataType.FLOAT32);
+ float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray};
+ dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray);
+ assertThat(dataType).isEqualTo(DataType.FLOAT32);
+ try {
+ double[] testDoubleArray = {0.783, 0.251};
+ NativeInterpreterWrapper.dataTypeOf(testDoubleArray);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("cannot resolve DataType of");
+ }
+ try {
+ Float[] testBoxedArray = {0.783f, 0.251f};
+ NativeInterpreterWrapper.dataTypeOf(testBoxedArray);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;");
+ }
+ }
+
+ @Test
+ public void testNumDimensions() {
+ int scalar = 1;
+ assertThat(NativeInterpreterWrapper.numDimensions(scalar)).isEqualTo(0);
+ int[][] array = {{2, 4}, {1, 9}};
+ assertThat(NativeInterpreterWrapper.numDimensions(array)).isEqualTo(2);
+ try {
+ int[] emptyArray = {};
+ NativeInterpreterWrapper.numDimensions(emptyArray);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("array lengths cannot be 0.");
+ }
+ }
+
+ @Test
+ public void testFillShape() {
+ int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}};
+ int num = NativeInterpreterWrapper.numDimensions(array);
+ int[] shape = new int[num];
+ NativeInterpreterWrapper.fillShape(array, 0, shape);
+ assertThat(num).isEqualTo(3);
+ assertThat(shape[0]).isEqualTo(2);
+ assertThat(shape[1]).isEqualTo(3);
+ assertThat(shape[2]).isEqualTo(1);
+ }
+}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java
new file mode 100644
index 0000000000..665c937cb6
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java
@@ -0,0 +1,32 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link org.tensorflow.lite.TensorFlowLite}. */
+@RunWith(JUnit4.class)
+public final class TensorFlowLiteTest {
+
+ @Test
+ public void testVersion() {
+ assertThat(TensorFlowLite.version()).isEqualTo("3");
+ }
+}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
new file mode 100644
index 0000000000..94b6632bb8
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
@@ -0,0 +1,105 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.fail;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link org.tensorflow.lite.Tensor}. */
+@RunWith(JUnit4.class)
+public final class TensorTest {
+
+ private static final String MODEL_PATH =
+ "tensorflow/contrib/lite/java/src/testdata/add.bin";
+
+ private NativeInterpreterWrapper wrapper;
+ private long nativeHandle;
+
+ @Before
+ public void setUp() {
+ wrapper = new NativeInterpreterWrapper(MODEL_PATH);
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ Object[] inputs = {fourD};
+ Tensor[] outputs = wrapper.run(inputs);
+ nativeHandle = outputs[0].nativeHandle;
+ }
+
+ @After
+ public void tearDown() {
+ wrapper.close();
+ }
+
+ @Test
+ public void testFromHandle() throws Exception {
+ Tensor tensor = Tensor.fromHandle(nativeHandle);
+ assertThat(tensor).isNotNull();
+ int[] expectedShape = {2, 8, 8, 3};
+ assertThat(tensor.shapeCopy).isEqualTo(expectedShape);
+ assertThat(tensor.dtype).isEqualTo(DataType.FLOAT32);
+ }
+
+ @Test
+ public void testCopyTo() {
+ Tensor tensor = Tensor.fromHandle(nativeHandle);
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ tensor.copyTo(parsedOutputs);
+ float[] outputOneD = parsedOutputs[0][0][0];
+ float[] expected = {3.69f, 19.62f, 23.43f};
+ assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+ }
+
+ @Test
+ public void testCopyToWrongType() {
+ Tensor tensor = Tensor.fromHandle(nativeHandle);
+ int[][][][] parsedOutputs = new int[2][8][8][3];
+ try {
+ tensor.copyTo(parsedOutputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains(
+ "Cannot convert an TensorFlowLite tensor with type "
+ + "FLOAT32 to a Java object of type [[[[I (which is compatible with the TensorFlowLite "
+ + "type INT32)");
+ }
+ }
+
+ @Test
+ public void testCopyToWrongShape() {
+ Tensor tensor = Tensor.fromHandle(nativeHandle);
+ float[][][][] parsedOutputs = new float[1][8][8][3];
+ try {
+ tensor.copyTo(parsedOutputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains(
+ "Shape of output target [1, 8, 8, 3] does not match "
+ + "with the shape of the Tensor [2, 8, 8, 3].");
+ }
+ }
+}
diff --git a/tensorflow/contrib/lite/java/src/testdata/add.bin b/tensorflow/contrib/lite/java/src/testdata/add.bin
new file mode 100644
index 0000000000..aef0fe3d82
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/testdata/add.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/java/src/testdata/float32.bin b/tensorflow/contrib/lite/java/src/testdata/float32.bin
new file mode 100644
index 0000000000..30b1264ca1
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/testdata/float32.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/java/src/testdata/int32.bin b/tensorflow/contrib/lite/java/src/testdata/int32.bin
new file mode 100644
index 0000000000..f6f3cf607a
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/testdata/int32.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/java/src/testdata/int64.bin b/tensorflow/contrib/lite/java/src/testdata/int64.bin
new file mode 100644
index 0000000000..c12aa41ca7
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/testdata/int64.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/java/src/testdata/invalid_model.bin b/tensorflow/contrib/lite/java/src/testdata/invalid_model.bin
new file mode 100644
index 0000000000..8156ac741c
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/testdata/invalid_model.bin
@@ -0,0 +1 @@
+This is an invalid model. \ No newline at end of file
diff --git a/tensorflow/contrib/lite/java/src/testdata/uint8.bin b/tensorflow/contrib/lite/java/src/testdata/uint8.bin
new file mode 100644
index 0000000000..f06c5cf584
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/testdata/uint8.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD
new file mode 100644
index 0000000000..2b4f37bc6c
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD
@@ -0,0 +1,30 @@
+# Description:
+# Internal helper function to test TF Lite API.
+
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+android_library(
+ name = "testhelper",
+ srcs = glob(
+ [
+ "*.java",
+ ],
+ ),
+ deps = [
+ "//tensorflow/contrib/lite/java:tensorflowlite_java",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
new file mode 100644
index 0000000000..8660cabf70
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
@@ -0,0 +1,35 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+/** A helper class for internal tests. */
+public class TestHelper {
+
+ /**
+ * Turns on/off NNAPI of an {@code Interpreter}.
+ *
+ * @param interpreter an instance of {@code Interpreter}. If it is not initialized, an {@code
+ * IllegalArgumentException} will be thrown.
+ * @param useNNAPI a boolean value indicating to turn on or off NNAPI.
+ */
+ public static void setUseNNAPI(Interpreter interpreter, boolean useNNAPI) {
+ if (interpreter != null && interpreter.wrapper != null) {
+ interpreter.wrapper.setUseNNAPI(useNNAPI);
+ } else {
+ throw new IllegalArgumentException("Interpreter has not initialized; Failed to setUseNNAPI.");
+ }
+ }
+}
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
new file mode 100644
index 0000000000..bbbfa3e741
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -0,0 +1,408 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+)
+
+tf_cc_test(
+ name = "optional_tensor_test",
+ size = "small",
+ srcs = ["optional_tensor_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "test_util",
+ testonly = 1,
+ srcs = ["test_util.cc"],
+ hdrs = ["test_util.h"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:schema_fbs_version",
+ "//tensorflow/contrib/lite:string_util",
+ "//tensorflow/core:lib",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "gemm_support",
+ srcs = [
+ "gemm_support.cc",
+ ],
+ hdrs = [
+ "gemm_support.h",
+ ],
+ copts = tflite_copts(),
+ deps = [
+ ":op_macros",
+ "//tensorflow/contrib/lite:context",
+ "@gemmlowp//:gemmlowp",
+ ],
+)
+
+cc_library(
+ name = "activation_functor",
+ hdrs = [
+ "activation_functor.h",
+ ],
+ deps = [
+ "//tensorflow/contrib/lite:builtin_op_data",
+ ],
+)
+
+cc_library(
+ name = "op_macros",
+ hdrs = [
+ "op_macros.h",
+ ],
+)
+
+cc_library(
+ name = "builtin_ops",
+ srcs = [
+ "activations.cc",
+ "add.cc",
+ "basic_rnn.cc",
+ "concatenation.cc",
+ "conv.cc",
+ "depthwise_conv.cc",
+ "embedding_lookup.cc",
+ "embedding_lookup_sparse.cc",
+ "fully_connected.cc",
+ "hashtable_lookup.cc",
+ "kernel_util.cc",
+ "l2norm.cc",
+ "local_response_norm.cc",
+ "lsh_projection.cc",
+ "lstm.cc",
+ "mul.cc",
+ "pooling.cc",
+ "register.cc",
+ "reshape.cc",
+ "resize_bilinear.cc",
+ "skip_gram.cc",
+ "space_to_depth.cc",
+ "svdf.cc",
+ ],
+ hdrs = [
+ "kernel_util.h",
+ "padding.h",
+ "register.h",
+ ],
+ # Suppress warnings that are introduced by Eigen Tensor.
+ copts = tflite_copts() + [
+ "-Wno-error=reorder",
+ ] + select({
+ "//tensorflow:ios": ["-Wno-error=invalid-partial-specialization"],
+ "//conditions:default": [
+ ],
+ }),
+ deps = [
+ ":activation_functor",
+ ":op_macros",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite/kernels:gemm_support",
+ "//tensorflow/contrib/lite/kernels/internal:optimized",
+ "//tensorflow/contrib/lite/kernels/internal:optimized_base",
+ "//tensorflow/contrib/lite/kernels/internal:quantization_util",
+ "//tensorflow/contrib/lite/kernels/internal:reference",
+ "//tensorflow/contrib/lite/kernels/internal:reference_base",
+ "//tensorflow/contrib/lite/kernels/internal:round",
+ "//tensorflow/contrib/lite/kernels/internal:tensor_utils",
+ "@farmhash_archive//:farmhash",
+ ],
+)
+
+tf_cc_test(
+ name = "activations_test",
+ size = "small",
+ srcs = ["activations_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "add_test",
+ size = "small",
+ srcs = ["add_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "concatenation_test",
+ size = "small",
+ srcs = ["concatenation_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "conv_test",
+ size = "small",
+ srcs = ["conv_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "depthwise_conv_test",
+ size = "small",
+ srcs = ["depthwise_conv_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "basic_rnn_test",
+ size = "small",
+ srcs = ["basic_rnn_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "l2norm_test",
+ size = "small",
+ srcs = ["l2norm_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "mul_test",
+ size = "small",
+ srcs = ["mul_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "reshape_test",
+ size = "small",
+ srcs = ["reshape_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "resize_bilinear_test",
+ size = "small",
+ srcs = ["resize_bilinear_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "svdf_test",
+ size = "small",
+ srcs = ["svdf_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "embedding_lookup_test",
+ size = "small",
+ srcs = ["embedding_lookup_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "embedding_lookup_sparse_test",
+ size = "small",
+ srcs = ["embedding_lookup_sparse_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "fully_connected_test",
+ size = "small",
+ srcs = ["fully_connected_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "local_response_norm_test",
+ size = "small",
+ srcs = ["local_response_norm_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "pooling_test",
+ size = "small",
+ srcs = ["pooling_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "softmax_test",
+ size = "small",
+ srcs = ["softmax_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "//tensorflow/contrib/lite/kernels/internal:reference_base",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "lsh_projection_test",
+ size = "small",
+ srcs = ["lsh_projection_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "hashtable_lookup_test",
+ size = "small",
+ srcs = ["hashtable_lookup_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "lstm_test",
+ size = "small",
+ srcs = ["lstm_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "skip_gram_test",
+ size = "small",
+ srcs = ["skip_gram_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "space_to_depth_test",
+ size = "small",
+ srcs = ["space_to_depth_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/kernels/activation_functor.h b/tensorflow/contrib/lite/kernels/activation_functor.h
new file mode 100644
index 0000000000..cfb3369e99
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/activation_functor.h
@@ -0,0 +1,58 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_
+
+#include <algorithm>
+#include <cmath>
+#include <cstdlib>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+
+namespace tflite {
+
+// Dynamic (non-fused) activation functor. perhaps it is worth having
+// template instantiation?
+// TODO(aselle): Make this more efficient by pulling the switch to conv_eval
+// using template inlining.
+class ActivationFunctor {
+ public:
+ explicit ActivationFunctor(TfLiteFusedActivation act) : act_(act) {}
+
+ float operator()(float a) const {
+ switch (act_) {
+ case kTfLiteActNone:
+ return a;
+ case kTfLiteActRelu:
+ return a < 0.f ? 0.f : a;
+ case kTfLiteActRelu6:
+ return std::max(0.f, std::min(a, 6.f));
+ case kTfLiteActTanh:
+ return std::tanh(a);
+ case kTfLiteActSigmoid:
+ return 1.0f / (1.0f + std::exp(-a));
+ default:
+ // TODO(aselle): More informative fatal error!
+ exit(1);
+ }
+ }
+
+ private:
+ TfLiteFusedActivation act_;
+};
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
new file mode 100644
index 0000000000..7ab60a33e5
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -0,0 +1,389 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdlib>
+#include <cstdio>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace activations {
+
+struct OpData {
+ int32_t input_multiplier = 0;
+ int input_left_shift = 0;
+ int32_t input_range_radius = 0;
+ int diff_min = 0;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ // This is a builtin op, so we don't use the contents in 'buffer', if any.
+ // Instead, we allocate a new object to carry information from Prepare() to
+ // Eval().
+ return new OpData;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+ return context->ResizeTensor(context, output,
+ TfLiteIntArrayCopy(input->dims));
+}
+
+TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) {
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+ if (input->type == kTfLiteUInt8) {
+ TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+ TF_LITE_ENSURE(context, output->params.scale == 1. / 256);
+
+ static constexpr int kInputIntegerBits = 4;
+
+ const double input_real_multiplier =
+ input->params.scale *
+ static_cast<double>(1 << (31 - kInputIntegerBits));
+
+ QuantizeMultiplierGreaterThanOne(input_real_multiplier,
+ &data->input_multiplier,
+ &data->input_left_shift);
+ data->input_range_radius =
+ CalculateInputRadius(kInputIntegerBits, data->input_left_shift);
+ }
+
+ return context->ResizeTensor(context, output,
+ TfLiteIntArrayCopy(input->dims));
+}
+
+TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+ TF_LITE_ENSURE(context,
+ NumDimensions(input) == 2 || NumDimensions(input) == 4);
+
+ if (input->type == kTfLiteUInt8) {
+ TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+ TF_LITE_ENSURE(context, output->params.scale == 1. / 256);
+
+ static const int kScaledDiffIntegerBits = 5;
+
+ tflite::PreprocessSoftmaxScaling(
+ params->beta, input->params.scale, kScaledDiffIntegerBits,
+ &data->input_multiplier, &data->input_left_shift);
+ data->diff_min = -1.0 * tflite::CalculateInputRadius(
+ kScaledDiffIntegerBits, data->input_left_shift);
+ }
+
+ return context->ResizeTensor(context, output,
+ TfLiteIntArrayCopy(input->dims));
+}
+
+TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ switch (input->type) {
+ case kTfLiteFloat32: {
+ size_t elements = input->bytes / sizeof(float);
+ float* in = input->data.f;
+ float* in_end = in + elements;
+ float* out = output->data.f;
+ for (; in < in_end; in++, out++) *out = std::max(0.f, *in);
+ return kTfLiteOk;
+ }
+ break;
+ default:
+ context->ReportError(context, "Only float32 supported currently.");
+ return kTfLiteError;
+ }
+}
+
+TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ switch (input->type) {
+ case kTfLiteFloat32: {
+ size_t elements = input->bytes / sizeof(float);
+ float* in = input->data.f;
+ float* in_end = in + elements;
+ float* out = output->data.f;
+ for (; in < in_end; in++, out++) {
+ *out = std::min(std::max(-1.f, *in), 1.f);
+ }
+ return kTfLiteOk;
+ } break;
+ default:
+ context->ReportError(context, "Only float32 supported currently.");
+ return kTfLiteError;
+ }
+}
+
+TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ switch (input->type) {
+ case kTfLiteFloat32: {
+ size_t elements = input->bytes / sizeof(float);
+ float* in = input->data.f;
+ float* in_end = in + elements;
+ float* out = output->data.f;
+ for (; in < in_end; in++, out++) *out = std::min(std::max(0.f, *in), 6.f);
+ return kTfLiteOk;
+ }
+ break;
+ default:
+ context->ReportError(context, "Only float32 supported currently.");
+ return kTfLiteError;
+ }
+}
+
+TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ switch (input->type) {
+ case kTfLiteFloat32: {
+ size_t elements = input->bytes / sizeof(float);
+ float* in = input->data.f;
+ float* in_end = in + elements;
+ float* out = output->data.f;
+ for (; in < in_end; in++, out++) *out = std::tanh(*in);
+ return kTfLiteOk;
+ }
+ break;
+ default:
+ context->ReportError(context, "Only float32 supported currently.");
+ return kTfLiteError;
+ }
+}
+
+// Sigmoid is also know as "Logistic".
+TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) {
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ switch (input->type) {
+ case kTfLiteFloat32: {
+ size_t elements = input->bytes / sizeof(float);
+ float* in = input->data.f;
+ float* in_end = in + elements;
+ float* out = output->data.f;
+ for (; in < in_end; in++, out++) *out = 1.f / (1.f + std::exp(-*in));
+ break;
+ }
+ case kTfLiteUInt8: {
+ optimized_ops::Logistic(
+ GetTensorData<uint8_t>(input), GetTensorDims(input),
+ input->params.zero_point, data->input_range_radius,
+ data->input_multiplier, data->input_left_shift,
+ GetTensorData<uint8_t>(output), GetTensorDims(output));
+ break;
+ }
+ default:
+ context->ReportError(context, "Only float32 supported currently.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+// Takes a 2D tensor and perform softmax along the second dimension.
+void Softmax2DFloat(TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params) {
+ const int batch_size = input->dims->data[0];
+ const int input_size = input->dims->data[1];
+ float* in = input->data.f;
+ float* out = output->data.f;
+ TF_LITE_ASSERT(input_size > 0);
+
+ // For each batch
+ for (int b = 0; b < batch_size; b++) {
+ // Find the max coeff.
+ float max_coeff = in[0];
+ for (int i = 1; i < input_size; i++) {
+ if (in[i] > max_coeff) max_coeff = in[i];
+ }
+
+ // Compute the normalized sum of exps.
+ float exp_sum = 0.0;
+ for (int i = 0; i < input_size; i++) {
+ out[i] = std::exp((in[i] - max_coeff) * params->beta);
+ exp_sum += out[i];
+ }
+
+ // Divide by the sum of exps.
+ float reciprocal_sum_exp = 1.f / exp_sum;
+ for (int i = 0; i < input_size; i++) {
+ out[i] *= reciprocal_sum_exp;
+ }
+
+ // Advance in and out pointers for the next batch.
+ in += input_size;
+ out += input_size;
+ }
+}
+
+void Softmax2DQuantized(TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params, OpData* data) {
+ // TODO(ahentz): this is arguably a dirty trick. Since the implementation
+ // always traverses the last dimension of a 4D tensor, we will pretend our 2D
+ // tensor is 4D in a special way. We will convert a (X, Y) shape into a (X,
+ // 1, 1, Y) shape.
+ const int batch_size = input->dims->data[0];
+ const int input_size = input->dims->data[1];
+ optimized_ops::Softmax(GetTensorData<uint8_t>(input),
+ GetTensorDims({batch_size, 1, 1, input_size}),
+ data->input_multiplier, data->input_left_shift,
+ data->diff_min, GetTensorData<uint8_t>(output),
+ GetTensorDims({batch_size, 1, 1, input_size}));
+}
+
+// Takes a 4D tensor and perform softmax along the forth dimension.
+void Softmax4DFloat(TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params) {
+ optimized_ops::Softmax(GetTensorData<float>(input), GetTensorDims(input),
+ params->beta, GetTensorData<float>(output),
+ GetTensorDims(output));
+}
+
+void Softmax4DQuantized(TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params, OpData* data) {
+ optimized_ops::Softmax(GetTensorData<uint8_t>(input), GetTensorDims(input),
+ data->input_multiplier, data->input_left_shift,
+ data->diff_min, GetTensorData<uint8_t>(output),
+ GetTensorDims(output));
+}
+
+TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+
+ // TODO(ahentz): consider an implementation that works for many (all?)
+ // dimensions.
+ switch (input->type) {
+ case kTfLiteFloat32: {
+ if (NumDimensions(input) == 2) {
+ Softmax2DFloat(input, output, params);
+ return kTfLiteOk;
+ }
+ if (NumDimensions(input) == 4) {
+ Softmax4DFloat(input, output, params);
+ return kTfLiteOk;
+ }
+ context->ReportError(context,
+ "Only 2D and 4D tensors supported currently.");
+ return kTfLiteError;
+ }
+ case kTfLiteUInt8: {
+ if (NumDimensions(input) == 2) {
+ Softmax2DQuantized(input, output, params, data);
+ return kTfLiteOk;
+ }
+ if (NumDimensions(input) == 4) {
+ Softmax4DQuantized(input, output, params, data);
+ return kTfLiteOk;
+ }
+ context->ReportError(context,
+ "Only 2D and 4D tensors supported currently.");
+ return kTfLiteError;
+ }
+ default:
+ context->ReportError(context,
+ "Only float32 and uint8_t supported currently.");
+ return kTfLiteError;
+ }
+}
+
+} // namespace activations
+
+TfLiteRegistration* Register_RELU() {
+ static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+ activations::GenericPrepare,
+ activations::ReluEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_RELU1() {
+ static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+ activations::GenericPrepare,
+ activations::Relu1Eval};
+ return &r;
+}
+
+TfLiteRegistration* Register_RELU6() {
+ static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+ activations::GenericPrepare,
+ activations::Relu6Eval};
+ return &r;
+}
+
+TfLiteRegistration* Register_TANH() {
+ static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+ activations::GenericPrepare,
+ activations::TanhEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_LOGISTIC() {
+ static TfLiteRegistration r = {activations::Init, activations::Free,
+ activations::SigmoidPrepare,
+ activations::SigmoidEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_SOFTMAX() {
+ static TfLiteRegistration r = {activations::Init, activations::Free,
+ activations::SoftmaxPrepare,
+ activations::SoftmaxEval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc
new file mode 100644
index 0000000000..f10aee7017
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/activations_test.cc
@@ -0,0 +1,323 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdarg>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseActivationsOpModel : public SingleOpModel {
+ public:
+ // Most activations don't take any options, so this constructor works for
+ // them.
+ BaseActivationsOpModel(BuiltinOperator type, TensorData input) {
+ input_ = AddInput(input);
+ if (input.type == TensorType_UINT8) {
+ output_ = AddOutput({input.type, {}, 0, 0, 1. / 256});
+ } else {
+ output_ = AddOutput({input.type, {}});
+ }
+ SetBuiltinOp(type, BuiltinOptions_NONE, 0);
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ // A dedicated constructor for SOFTMAX, which does some options.
+ BaseActivationsOpModel(float softmax_beta, TensorData input) {
+ input_ = AddInput(input);
+ if (input.type == TensorType_UINT8) {
+ output_ = AddOutput({input.type, {}, 0, 0, 1. / 256});
+ } else {
+ output_ = AddOutput({input.type, {}});
+ }
+ SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions,
+ CreateSoftmaxOptions(builder_, softmax_beta).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ protected:
+ int input_;
+ int output_;
+};
+
+class FloatActivationsOpModel : public BaseActivationsOpModel {
+ public:
+ using BaseActivationsOpModel::BaseActivationsOpModel;
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+// TODO(ahentz): I don't quite understand the tradeoffs in the quantized
+// implementation of sigmoid and software, but a tolerance of twice the output
+// scale seems reasonable. We might want to change this if we have a better
+// theoretical bound.
+const float kQuantizedTolerance = 2 * (1. / 256);
+
+class QuantizedActivationsOpModel : public BaseActivationsOpModel {
+ public:
+ using BaseActivationsOpModel::BaseActivationsOpModel;
+
+ void SetInput(std::initializer_list<float> data) {
+ QuantizeAndPopulate<uint8_t>(input_, data);
+ }
+ std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
+};
+
+TEST(FloatActivationsOpTest, Relu) {
+ FloatActivationsOpModel m(BuiltinOperator_RELU,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
+ m.SetInput({
+ 0, -6, 2, 4, //
+ 3, -2, 10, 1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 0, 0, 2, 4, //
+ 3, 0, 10, 1, //
+ }));
+}
+
+TEST(FloatActivationsOpTest, Relu1) {
+ FloatActivationsOpModel m(BuiltinOperator_RELU1,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
+ m.SetInput({
+ 0.0, -0.6, 0.2, -0.4, //
+ 0.3, -2.0, 1.1, -0.1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 0.0, -0.6, 0.2, -0.4, //
+ 0.3, -1.0, 1.0, -0.1, //
+ }));
+}
+
+TEST(FloatActivationsOpTest, Relu6) {
+ FloatActivationsOpModel m(BuiltinOperator_RELU6,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
+ m.SetInput({
+ 0, -6, 2, 4, //
+ 3, -2, 10, 1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 0, 0, 2, 4, //
+ 3, 0, 6, 1, //
+ }));
+}
+
+TEST(FloatActivationsOpTest, Tanh) {
+ FloatActivationsOpModel m(BuiltinOperator_TANH,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
+ m.SetInput({
+ 0, -6, 2, 4, //
+ 3, -2, 10, 1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 0, -0.9999877, 0.9640275, 0.999329, //
+ 0.99505475, -0.9640275, 1, 0.7615941, //
+ })));
+}
+
+TEST(FloatActivationsOpTest, Sigmoid) {
+ FloatActivationsOpModel m(BuiltinOperator_LOGISTIC,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
+ m.SetInput({
+ 0, -6, 2, 4, //
+ 3, -2, 10, 1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 0.5, 0.002473, 0.880797, 0.982014, //
+ 0.952574, 0.119203, 0.999955, 0.731059, //
+ })));
+}
+
+TEST(QuantizedActivationsOpTest, Sigmoid) {
+ QuantizedActivationsOpModel m(
+ BuiltinOperator_LOGISTIC,
+ /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -10, 10});
+ m.SetInput({
+ 0, -6, 2, 4, //
+ 3, -2, 10, 1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 0.5, 0.002473, 0.880797, 0.982014, //
+ 0.952574, 0.119203, 0.999955, 0.731059, //
+ },
+ kQuantizedTolerance)));
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({128, 1, 227, 251, 244, 32, 255, 188}));
+}
+
+TEST(FloatActivationsOpTest, Softmax4D) {
+ FloatActivationsOpModel m(0.1,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 1, 4}});
+ m.SetInput({
+ 0, -6, 2, 4, // depth = 0
+ 3, -2, 10, 1, // depth = 1
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ .23463, .12877, .28658, .35003, //
+ .22528, .13664, .45365, .18443, //
+ })));
+
+ // Same input, but a different shape.
+ FloatActivationsOpModel m2(0.1,
+ /*input=*/{TensorType_FLOAT32, {4, 1, 1, 2}});
+ m2.SetInput({
+ 0, -6, //
+ 2, 4, //
+ 3, -2, //
+ 10, 1, //
+ });
+ m2.Invoke();
+ EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 0.645656, 0.354344, //
+ 0.450166, 0.549834, //
+ 0.622459, 0.377541, //
+ 0.710949, 0.28905, //
+ })));
+}
+
+TEST(QuantizedActivationsOpTest, Softmax4D) {
+ QuantizedActivationsOpModel m(
+ 0.1,
+ /*input=*/{TensorType_UINT8, {1, 2, 1, 4}, -10, 10});
+ m.SetInput({
+ 0, -6, 2, 4, // depth = 0
+ 3, -2, 10, 1, // depth = 1
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ .23463, .12877, .28658, .35003, //
+ .22528, .13664, .45365, .18443, //
+ },
+ kQuantizedTolerance)));
+
+ // Same input, but a different shape.
+ QuantizedActivationsOpModel m2(
+ 0.1,
+ /*input=*/{TensorType_UINT8, {4, 1, 1, 2}, -10, 10});
+ m2.SetInput({
+ 0, -6, //
+ 2, 4, //
+ 3, -2, //
+ 10, 1, //
+ });
+ m2.Invoke();
+ EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
+ {
+ 0.645656, 0.354344, //
+ 0.450166, 0.549834, //
+ 0.622459, 0.377541, //
+ 0.710949, 0.28905, //
+ },
+ kQuantizedTolerance)));
+}
+
+TEST(FloatActivationsOpTest, Softmax2D) {
+ FloatActivationsOpModel m(0.1,
+ /*input=*/{TensorType_FLOAT32, {2, 4}});
+ m.SetInput({
+ 0, -6, 2, 4, //
+ 3, -2, 10, 1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ .23463, .12877, .28658, .35003, //
+ .22528, .13664, .45365, .18443, //
+ })));
+
+ // Same input, but a different shape.
+ FloatActivationsOpModel m2(0.1,
+ /*input=*/{TensorType_FLOAT32, {4, 2}});
+ m2.SetInput({
+ 0, -6, //
+ 2, 4, //
+ 3, -2, //
+ 10, 1, //
+ });
+ m2.Invoke();
+ EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 0.645656, 0.354344, //
+ 0.450166, 0.549834, //
+ 0.622459, 0.377541, //
+ 0.710949, 0.28905, //
+ })));
+}
+
+TEST(QuantizedActivationsOpTest, Softmax2D) {
+ QuantizedActivationsOpModel m(0.1,
+ /*input=*/{TensorType_UINT8, {2, 4}, -10, 10});
+ m.SetInput({
+ 0, -6, 2, 4, //
+ 3, -2, 10, 1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ .23463, .12877, .28658, .35003, //
+ .22528, .13664, .45365, .18443, //
+ },
+ kQuantizedTolerance)));
+
+ // Same input, but a different shape.
+ QuantizedActivationsOpModel m2(0.1,
+ /*input=*/{TensorType_UINT8, {4, 2}, -10, 10});
+ m2.SetInput({
+ 0, -6, //
+ 2, 4, //
+ 3, -2, //
+ 10, 1, //
+ });
+ m2.Invoke();
+ EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
+ {
+ 0.645656, 0.354344, //
+ 0.450166, 0.549834, //
+ 0.622459, 0.377541, //
+ 0.710949, 0.28905, //
+ },
+ kQuantizedTolerance)));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc
new file mode 100644
index 0000000000..0e10a249ab
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/add.cc
@@ -0,0 +1,184 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace add {
+
+// This file has three implementation of Add.
+enum KernelType {
+ kReference,
+ kGenericOptimized, // Neon-free
+ kNeonOptimized,
+};
+
+constexpr int kInputTensor1 = 0;
+constexpr int kInputTensor2 = 1;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2));
+ for (int i = 0; i < NumDimensions(input1); ++i) {
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i),
+ SizeOfDimension(input2, i));
+ }
+
+ TF_LITE_ENSURE_EQ(context, input1->type, output->type);
+ TF_LITE_ENSURE_EQ(context, input2->type, output->type);
+
+ TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims);
+ return context->ResizeTensor(context, output, output_size);
+}
+
+template <KernelType kernel_type>
+void EvalAddFloat(TfLiteContext* context, TfLiteNode* node,
+ TfLiteAddParams* params, TfLiteTensor* input1,
+ TfLiteTensor* input2, TfLiteTensor* output) {
+ float output_activation_min, output_activation_max;
+ CalculateActivationRangeFloat(params->activation, &output_activation_min,
+ &output_activation_max);
+#define TF_LITE_ADD(type) \
+ type::Add(GetTensorData<float>(input1), GetTensorDims(input1), \
+ GetTensorData<float>(input2), GetTensorDims(input2), \
+ output_activation_min, output_activation_max, \
+ GetTensorData<float>(output), GetTensorDims(output))
+ if (kernel_type == kReference) {
+ TF_LITE_ADD(reference_ops);
+ } else {
+ TF_LITE_ADD(optimized_ops);
+ }
+#undef TF_LITE_ADD
+}
+
+template <KernelType kernel_type>
+void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteAddParams* params, TfLiteTensor* input1,
+ TfLiteTensor* input2, TfLiteTensor* output) {
+ auto input1_offset = -input1->params.zero_point;
+ auto input2_offset = -input2->params.zero_point;
+ auto output_offset = output->params.zero_point;
+ const int left_shift = 20;
+ const double twice_max_input_scale =
+ 2 * std::max(input1->params.scale, input2->params.scale);
+ const double real_input1_multiplier =
+ input1->params.scale / twice_max_input_scale;
+ const double real_input2_multiplier =
+ input2->params.scale / twice_max_input_scale;
+ const double real_output_multiplier =
+ twice_max_input_scale / ((1 << left_shift) * output->params.scale);
+
+ int32 input1_multiplier;
+ int input1_shift;
+ QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier,
+ &input1_shift);
+ int32 input2_multiplier;
+ int input2_shift;
+ QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier,
+ &input2_shift);
+ int32 output_multiplier;
+ int output_shift;
+ QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier,
+ &output_shift);
+
+ int32 output_activation_min, output_activation_max;
+ CalculateActivationRangeUint8(params->activation, output,
+ &output_activation_min, &output_activation_max);
+
+#define TF_LITE_ADD(type) \
+ type::BroadcastAdd( \
+ left_shift, GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
+ input1_offset, input1_multiplier, input1_shift, \
+ GetTensorData<uint8_t>(input2), GetTensorDims(input2), input2_offset, \
+ input2_multiplier, input2_shift, output_offset, output_multiplier, \
+ output_shift, output_activation_min, output_activation_max, \
+ GetTensorData<uint8_t>(output), GetTensorDims(output));
+
+ if (kernel_type == kReference) {
+ TF_LITE_ADD(reference_ops);
+ } else {
+ TF_LITE_ADD(optimized_ops);
+ }
+#undef TF_LITE_ADD
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteAddParams*>(node->builtin_data);
+
+ TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ if (output->type == kTfLiteFloat32) {
+ EvalAddFloat<kernel_type>(context, node, params, input1, input2, output);
+ } else if (output->type == kTfLiteUInt8) {
+ EvalAddQuantized<kernel_type>(context, node, params, input1, input2,
+ output);
+ } else {
+ context->ReportError(context,
+ "Inputs and outputs not all float|unit8 types.");
+ return kTfLiteError;
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace add
+
+TfLiteRegistration* Register_ADD_REF() {
+ static TfLiteRegistration r = {nullptr, nullptr, add::Prepare,
+ add::Eval<add::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_ADD_GENERIC_OPT() {
+ static TfLiteRegistration r = {nullptr, nullptr, add::Prepare,
+ add::Eval<add::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_ADD_NEON_OPT() {
+ static TfLiteRegistration r = {nullptr, nullptr, add::Prepare,
+ add::Eval<add::kNeonOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_ADD() {
+#ifdef USE_NEON
+ return Register_ADD_NEON_OPT();
+#else
+ return Register_ADD_GENERIC_OPT();
+#endif
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/add_test.cc b/tensorflow/contrib/lite/kernels/add_test.cc
new file mode 100644
index 0000000000..8e12a837c4
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/add_test.cc
@@ -0,0 +1,171 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseAddOpModel : public SingleOpModel {
+ public:
+ BaseAddOpModel(const TensorData& input, const TensorData& output,
+ ActivationFunctionType activation_type) {
+ input1_ = AddInput(input);
+ input2_ = AddInput(input);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions,
+ CreateAddOptions(builder_, activation_type).Union());
+ BuildInterpreter({GetShape(input1_), GetShape(input2_)});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ protected:
+ int input1_;
+ int input2_;
+ int output_;
+};
+
+class FloatAddOpModel : public BaseAddOpModel {
+ public:
+ using BaseAddOpModel::BaseAddOpModel;
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+class QuantizedAddOpModel : public BaseAddOpModel {
+ public:
+ using BaseAddOpModel::BaseAddOpModel;
+
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
+};
+
+// for quantized Add, the error shouldn't exceed 2*step
+float GetTolerance(int min, int max) {
+ float kQuantizedStep = (max - min) / 255.0;
+ float kQuantizedTolerance = 2.0 * kQuantizedStep;
+ return kQuantizedTolerance;
+}
+
+TEST(FloatAddOpModel, NoActivation) {
+ FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
+ m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3}));
+}
+
+TEST(FloatAddOpModel, ActivationRELU1) {
+ FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU1);
+ m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
+ m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.0, 0.4, 1.0, 1.0}));
+}
+
+TEST(FloatAddOpModel, VariousInputShapes) {
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ FloatAddOpModel m({TensorType_FLOAT32, test_shapes[i]},
+ {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0});
+ m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5, 1.1, 0.1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({-1.9, 0.4, 1.0, 1.3, 2.2, 2.1}))
+ << "With shape number " << i;
+ }
+}
+
+TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) {
+ float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
+ std::vector<std::initializer_list<float>> inputs1 = {
+ {0.1, 0.2, 0.3, 0.4}, {-0.8, 0.2, 0.4, 0.7}, {-0.8, 0.2, 0.7, 0.3}};
+ std::vector<std::initializer_list<float>> inputs2 = {
+ {0.6, 0.4, 0.3, 0.1}, {0.6, 0.4, 0.5, -0.8}, {0.6, 0.4, -0.8, 0.5}};
+ std::vector<std::initializer_list<float>> results = {
+ {0.7, 0.6, 0.6, 0.5}, {-0.2, 0.6, 0.9, -0.1}, {-0.2, 0.6, -0.1, 0.8}};
+ for (int i = 0; i < inputs1.size(); ++i) {
+ QuantizedAddOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0},
+ {TensorType_UINT8, {}, -1.0, 1.0},
+ ActivationFunctionType_NONE);
+ m.QuantizeAndPopulate<uint8_t>(m.input1(), inputs1[i]);
+ m.QuantizeAndPopulate<uint8_t>(m.input2(), inputs2[i]);
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
+ results[i], kQuantizedTolerance)))
+ << "With test number " << i;
+ }
+}
+
+TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU1) {
+ float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
+ std::vector<std::initializer_list<float>> inputs1 = {{-0.8, 0.2, 0.9, 0.7},
+ {-0.8, 0.2, 0.7, 0.3}};
+ std::vector<std::initializer_list<float>> inputs2 = {{0.6, 0.4, 0.9, -0.8},
+ {0.6, 0.4, -0.8, 0.5}};
+ std::vector<std::initializer_list<float>> results = {{-0.2, 0.6, 1.0, -0.1},
+ {-0.2, 0.6, -0.1, 0.8}};
+ for (int i = 0; i < inputs1.size(); ++i) {
+ QuantizedAddOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0},
+ {TensorType_UINT8, {}, -1.0, 1.0},
+ ActivationFunctionType_RELU1);
+ m.QuantizeAndPopulate<uint8_t>(m.input1(), inputs1[i]);
+ m.QuantizeAndPopulate<uint8_t>(m.input2(), inputs2[i]);
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
+ results[i], kQuantizedTolerance)))
+ << "With test number " << i;
+ }
+}
+
+TEST(QuantizedAddOpModel, QuantizedVariousInputShapes) {
+ float kQuantizedTolerance = GetTolerance(-3.0, 3.0);
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ QuantizedAddOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0},
+ {TensorType_UINT8, {}, -3.0, 3.0},
+ ActivationFunctionType_NONE);
+ m.QuantizeAndPopulate<uint8_t>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0});
+ m.QuantizeAndPopulate<uint8_t>(m.input2(), {0.1, 0.3, 0.3, 0.5, 1.1, 0.1});
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({-1.9, 0.5, 1.0, 1.3, 2.2, 2.1},
+ kQuantizedTolerance)))
+ << "With shape number " << i;
+ }
+}
+
+} // namespace
+} // namespace tflite
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc
new file mode 100644
index 0000000000..3cee43c68b
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc
@@ -0,0 +1,161 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdlib>
+#include <cstdio>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace rnn {
+
+constexpr int kInputTensor = 0;
+constexpr int kWeightsTensor = 1;
+constexpr int kRecurrentWeightsTensor = 2;
+constexpr int kBiasTensor = 3;
+constexpr int KHiddenStateTensor = 0;
+constexpr int kOutputTensor = 1;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ // Check we have all the inputs and outputs we need.
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+
+ TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
+ TfLiteTensor* input_weights =
+ &context->tensors[node->inputs->data[kWeightsTensor]];
+ TfLiteTensor* recurrent_weights =
+ &context->tensors[node->inputs->data[kRecurrentWeightsTensor]];
+ TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]];
+
+ // Check all the parameters of tensor match within themselves and match the
+ // input configuration.
+ const int batch_size = input->dims->data[0];
+ const int num_units = input_weights->dims->data[0];
+ TF_LITE_ASSERT_EQ(input->dims->data[1], input_weights->dims->data[1]);
+ TF_LITE_ASSERT_EQ(input_weights->dims->data[0], bias->dims->data[0]);
+ TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[0], bias->dims->data[0]);
+ TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]);
+
+ TfLiteTensor* hidden_state =
+ &context->tensors[node->outputs->data[KHiddenStateTensor]];
+ TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];
+
+ // Resize state.
+ TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2);
+ hidden_state_size_array->data[0] = batch_size;
+ hidden_state_size_array->data[1] = num_units;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, hidden_state,
+ hidden_state_size_array));
+
+ // Mark hidden state as a persistent tensor.
+ hidden_state->allocation_type = kTfLiteArenaRwPersistent;
+
+ // Resize output.
+ TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
+ output_size_array->data[0] = batch_size;
+ output_size_array->data[1] = num_units;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output,
+ output_size_array));
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteRNNParams*>(node->builtin_data);
+
+ TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
+ TfLiteTensor* input_weights =
+ &context->tensors[node->inputs->data[kWeightsTensor]];
+ TfLiteTensor* recurrent_weights =
+ &context->tensors[node->inputs->data[kRecurrentWeightsTensor]];
+ TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]];
+ TfLiteTensor* hidden_state =
+ &context->tensors[node->outputs->data[KHiddenStateTensor]];
+ TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];
+
+ // Initialize the pointer bias.
+ const float* bias_ptr = bias->data.f;
+
+ const int batch_size = input->dims->data[0];
+ const int num_units = input_weights->dims->data[0];
+ const int input_size = input->dims->data[1];
+ const int input_weights_stride = input_weights->dims->data[1];
+ const int recurrent_weights_stride = recurrent_weights->dims->data[1];
+
+ // For each batch
+ for (int b = 0; b < batch_size; b++) {
+ // Initialize the pointer to input, output and bias.
+ const float* input_ptr_batch = input->data.f + b * input_size;
+ float* output_ptr_batch = output->data.f + b * num_units;
+ float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units;
+
+ // Initialize input_weights and recurrent_weights.
+ const float* input_weights_ptr = input_weights->data.f;
+ const float* recurrent_weights_ptr = recurrent_weights->data.f;
+
+ // Output = bias
+ for (int o = 0; o < num_units; o++) {
+ output_ptr_batch[o] = bias_ptr[o];
+ }
+
+ // Output += input * input_weights
+ for (int o = 0; o < num_units; o++) {
+ for (int i = 0; i < input_size; i++) {
+ output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
+ }
+ input_weights_ptr += input_weights_stride;
+ }
+
+ // Output += recurrent_weights * hidden_state
+ for (int o = 0; o < num_units; o++) {
+ for (int h = 0; h < num_units; h++) {
+ output_ptr_batch[o] +=
+ hidden_state_ptr_batch[h] * recurrent_weights_ptr[h];
+ }
+ recurrent_weights_ptr += recurrent_weights_stride;
+ }
+
+ // Output = activation(Output) and update hidden_state
+ for (int o = 0; o < num_units; o++) {
+ output_ptr_batch[o] =
+ (ActivationFunctor(params->activation))(output_ptr_batch[o]);
+ hidden_state_ptr_batch[o] = output_ptr_batch[o];
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace rnn
+
+TfLiteRegistration* Register_RNN() {
+ static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+ rnn::Prepare, rnn::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc
new file mode 100644
index 0000000000..dfa75655bc
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc
@@ -0,0 +1,267 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite RNN op.
+
+#include <vector>
+#include <iomanip>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+static float rnn_input[] = {
+ 0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133,
+ 0.43773448, 0.60379338, 0.35562468, -0.69424844, -0.93421471,
+ -0.87287879, 0.37144363, -0.62476718, 0.23791671, 0.40060222,
+ 0.1356622, -0.99774903, -0.98858172, -0.38952237, -0.47685933,
+ 0.31073618, 0.71511042, -0.63767755, -0.31729108, 0.33468103,
+ 0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043,
+ -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007,
+ -0.61777675, -0.21095741, 0.41213346, 0.73784804, 0.094794154,
+ 0.47791874, 0.86496925, -0.53376222, 0.85315156, 0.10288584,
+ 0.86684, -0.011186242, 0.10513687, 0.87825835, 0.59929144,
+ 0.62827742, 0.18899453, 0.31440187, 0.99059987, 0.87170351,
+ -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719,
+ 0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567,
+ -0.66609079, 0.59098077, 0.73017097, 0.74604273, 0.32882881,
+ -0.17503482, 0.22396147, 0.19379807, 0.29120302, 0.077113032,
+ -0.70331609, 0.15804303, -0.93407321, 0.40182066, 0.036301374,
+ 0.66521823, 0.0300982, -0.7747041, -0.02038002, 0.020698071,
+ -0.90300065, 0.62870288, -0.23068321, 0.27531278, -0.095755219,
+ -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682,
+ 0.43519354, 0.14744234, 0.62589407, 0.1653645, -0.10651493,
+ -0.045277178, 0.99032974, -0.88255352, -0.85147917, 0.28153265,
+ 0.19455957, -0.55479527, -0.56042433, 0.26048636, 0.84702539,
+ 0.47587705, -0.074295521, -0.12287641, 0.70117295, 0.90532446,
+ 0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017,
+ -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563,
+ 0.93455386, -0.6324693, -0.083922029};
+
+static float rnn_golden_output[] = {
+ 0.496726, 0, 0.965996, 0, 0.0584254, 0,
+ 0, 0.12315, 0, 0, 0.612266, 0.456601,
+ 0, 0.52286, 1.16099, 0.0291232,
+
+ 0, 0, 0.524901, 0, 0, 0,
+ 0, 1.02116, 0, 1.35762, 0, 0.356909,
+ 0.436415, 0.0355727, 0, 0,
+
+ 0, 0, 0, 0.262335, 0, 0,
+ 0, 1.33992, 0, 2.9739, 0, 0,
+ 1.31914, 2.66147, 0, 0,
+
+ 0.942568, 0, 0, 0, 0.025507, 0,
+ 0, 0, 0.321429, 0.569141, 1.25274, 1.57719,
+ 0.8158, 1.21805, 0.586239, 0.25427,
+
+ 1.04436, 0, 0.630725, 0, 0.133801, 0.210693,
+ 0.363026, 0, 0.533426, 0, 1.25926, 0.722707,
+ 0, 1.22031, 1.30117, 0.495867,
+
+ 0.222187, 0, 0.72725, 0, 0.767003, 0,
+ 0, 0.147835, 0, 0, 0, 0.608758,
+ 0.469394, 0.00720298, 0.927537, 0,
+
+ 0.856974, 0.424257, 0, 0, 0.937329, 0,
+ 0, 0, 0.476425, 0, 0.566017, 0.418462,
+ 0.141911, 0.996214, 1.13063, 0,
+
+ 0.967899, 0, 0, 0, 0.0831304, 0,
+ 0, 1.00378, 0, 0, 0, 1.44818,
+ 1.01768, 0.943891, 0.502745, 0,
+
+ 0.940135, 0, 0, 0, 0, 0,
+ 0, 2.13243, 0, 0.71208, 0.123918, 1.53907,
+ 1.30225, 1.59644, 0.70222, 0,
+
+ 0.804329, 0, 0.430576, 0, 0.505872, 0.509603,
+ 0.343448, 0, 0.107756, 0.614544, 1.44549, 1.52311,
+ 0.0454298, 0.300267, 0.562784, 0.395095,
+
+ 0.228154, 0, 0.675323, 0, 1.70536, 0.766217,
+ 0, 0, 0, 0.735363, 0.0759267, 1.91017,
+ 0.941888, 0, 0, 0,
+
+ 0, 0, 1.5909, 0, 0, 0,
+ 0, 0.5755, 0, 0.184687, 0, 1.56296,
+ 0.625285, 0, 0, 0,
+
+ 0, 0, 0.0857888, 0, 0, 0,
+ 0, 0.488383, 0.252786, 0, 0, 0,
+ 1.02817, 1.85665, 0, 0,
+
+ 0.00981836, 0, 1.06371, 0, 0, 0,
+ 0, 0, 0, 0.290445, 0.316406, 0,
+ 0.304161, 1.25079, 0.0707152, 0,
+
+ 0.986264, 0.309201, 0, 0, 0, 0,
+ 0, 1.64896, 0.346248, 0, 0.918175, 0.78884,
+ 0.524981, 1.92076, 2.07013, 0.333244,
+
+ 0.415153, 0.210318, 0, 0, 0, 0,
+ 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453,
+ 0.628881, 3.58099, 1.49974, 0
+};
+
+class RNNOpModel : public SingleOpModel {
+ public:
+ RNNOpModel(int batches, int units, int size)
+ : batches_(batches), units_(units), input_size_(size) {
+ input_ = AddInput(TensorType_FLOAT32);
+ weights_ = AddInput(TensorType_FLOAT32);
+ recurrent_weights_ = AddInput(TensorType_FLOAT32);
+ bias_ = AddInput(TensorType_FLOAT32);
+ hidden_state_ = AddOutput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(
+ BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
+ CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
+ BuildInterpreter({{batches_, input_size_},
+ {units_, input_size_},
+ {units_, units_},
+ {units_}});
+ }
+
+ void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
+
+ void SetWeights(std::initializer_list<float> f) {
+ PopulateTensor(weights_, f);
+ }
+
+ void SetRecurrentWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_weights_, f);
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+
+ void SetInput(int offset, float* begin, float* end) {
+ PopulateTensor(input_, offset, begin, end);
+ }
+
+ void ResetHiddenState() {
+ const int zero_buffer_size = units_ * batches_;
+ std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
+ memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
+ PopulateTensor(hidden_state_, 0, zero_buffer.get(),
+ zero_buffer.get() + zero_buffer_size);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ int input_size() { return input_size_; }
+ int num_units() { return units_; }
+ int num_batches() { return batches_; }
+
+ private:
+ int input_;
+ int weights_;
+ int recurrent_weights_;
+ int bias_;
+ int hidden_state_;
+ int output_;
+
+ int batches_;
+ int units_;
+ int input_size_;
+};
+
+TEST(FullyConnectedOpTest, BlackBoxTest) {
+ RNNOpModel rnn(2, 16, 8);
+ rnn.SetWeights(
+ {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346,
+ 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399,
+ 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113,
+ -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512,
+ -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188,
+ -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158,
+ -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241,
+ 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183,
+ 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303,
+ 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884,
+ -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726,
+ 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644,
+ -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461,
+ -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158,
+ 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042,
+ 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012,
+ 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345,
+ -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884,
+ 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274,
+ 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934,
+ -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077,
+ 0.277308, 0.415818});
+
+ rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068,
+ -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796,
+ 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964,
+ -0.37609905});
+
+ rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1});
+
+ rnn.ResetHiddenState();
+ const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
+ (rnn.input_size() * rnn.num_batches());
+
+ for (int i = 0; i < input_sequence_size; i++) {
+ float* batch_start = rnn_input + i * rnn.input_size();
+ float* batch_end = batch_start + rnn.input_size();
+ rnn.SetInput(0, batch_start, batch_end);
+ rnn.SetInput(rnn.input_size(), batch_start, batch_end);
+
+ rnn.Invoke();
+
+ float* golden_start = rnn_golden_output + i * rnn.num_units();
+ float* golden_end = golden_start + rnn.num_units();
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+ expected.insert(expected.end(), golden_start, golden_end);
+
+ EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ }
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc
new file mode 100644
index 0000000000..9e7a1233da
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/concatenation.cc
@@ -0,0 +1,200 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace concatenation {
+
+// This file has two implementation of Concatenation.
+enum KernelType {
+ kReference,
+ kGenericOptimized,
+};
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
+ int axis = params->axis;
+ int num_inputs = node->inputs->size;
+
+ // The number of dimensions of the input tensors must match, and all
+ // dimensions except 'axis' must be equal.
+ TfLiteTensor* t0 = &context->tensors[node->inputs->data[0]];
+ TfLiteType input_type = t0->type;
+ TF_LITE_ENSURE(context, axis >= 0);
+ TF_LITE_ENSURE(context, axis < t0->dims->size);
+
+ // TODO(ahentz): These are limitations of our implementation that could be
+ // removed with a bit of effort.
+ TF_LITE_ENSURE(context, t0->dims->size <= 4);
+ TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
+ TF_LITE_ENSURE(context,
+ input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8);
+
+ // Output dimensions will match input dimensions, except 'axis', which
+ // will be the sum of inputs
+ int sum_axis = t0->dims->data[axis];
+ for (int i = 1; i < num_inputs; ++i) {
+ TfLiteTensor* t = &context->tensors[node->inputs->data[i]];
+ TF_LITE_ENSURE_EQ(context, t->dims->size, t0->dims->size);
+ TF_LITE_ENSURE_EQ(context, t->type, input_type);
+ if (input_type == kTfLiteUInt8) {
+ TF_LITE_ENSURE_EQ(context, t->params.zero_point, t0->params.zero_point);
+ TF_LITE_ENSURE_EQ(context, t->params.scale, t0->params.scale);
+ }
+ for (int d = 0; d < t0->dims->size; ++d) {
+ if (d == axis) {
+ sum_axis += t->dims->data[axis];
+ } else {
+ TF_LITE_ENSURE_EQ(context, t->dims->data[d], t0->dims->data[d]);
+ }
+ }
+ }
+
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(t0->dims->size);
+ for (int d = 0; d < t0->dims->size; ++d) {
+ output_size->data[d] = (d == axis) ? sum_axis : t0->dims->data[d];
+ }
+
+ TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+ TF_LITE_ENSURE_EQ(context, output->type, input_type);
+ if (input_type == kTfLiteUInt8) {
+ TF_LITE_ENSURE_EQ(context, output->params.zero_point,
+ t0->params.zero_point);
+ TF_LITE_ENSURE_EQ(context, output->params.scale, t0->params.scale);
+ }
+
+ return context->ResizeTensor(context, output, output_size);
+}
+
+template <typename T>
+class VectorOfInputs {
+ public:
+ VectorOfInputs(const TfLiteContext& context, const TfLiteIntArray& inputs) {
+ int num_inputs = inputs.size;
+
+ all_data_.reserve(num_inputs);
+ all_dims_.reserve(num_inputs);
+ all_dims_ptr_.reserve(num_inputs);
+
+ for (int i = 0; i < num_inputs; ++i) {
+ TfLiteTensor* input = &context.tensors[inputs.data[i]];
+ all_data_.push_back(GetTensorData<T>(input));
+ all_dims_.push_back(GetTensorDims(input));
+ }
+
+ // Taking the pointer from inside a std::vector is only OK if the vector is
+ // never modified, so we populate all_dims in the previous loop and then we
+ // are free to grab iterators here.
+ for (int i = 0; i < num_inputs; ++i) {
+ all_dims_ptr_.push_back(&all_dims_[i]);
+ }
+ }
+ const T* const* data() const { return all_data_.data(); }
+ const Dims<4>* const* dims() const { return all_dims_ptr_.data(); }
+
+ private:
+ std::vector<T*> all_data_;
+ std::vector<Dims<4>> all_dims_;
+ std::vector<Dims<4>*> all_dims_ptr_;
+};
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
+
+ TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+
+// TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should
+// allocate and populate these during Prepare().
+// TODO(ycling): Activation function parameter is ignored. For now we dont have
+// a model with a Concatenation with fused activation function.
+#define TF_LITE_CONCATENATION(type, scalar) \
+ VectorOfInputs<scalar> all_inputs(*context, *node->inputs); \
+ type::Concatenation<FusedActivationFunctionType::kNone, scalar>( \
+ RemapDim(NumDimensions(output), params->axis), all_inputs.data(), \
+ all_inputs.dims(), node->inputs->size, GetTensorData<scalar>(output), \
+ GetTensorDims(output))
+
+ switch (output->type) { // Already know in/outtypes are same.
+ case kTfLiteFloat32:
+ if (kernel_type == kReference) {
+ TF_LITE_CONCATENATION(reference_ops, float);
+ } else {
+ TF_LITE_CONCATENATION(optimized_ops, float);
+ }
+ break;
+ case kTfLiteUInt8:
+ if (kernel_type == kReference) {
+ TF_LITE_CONCATENATION(reference_ops, uint8_t);
+ } else {
+ TF_LITE_CONCATENATION(optimized_ops, uint8_t);
+ }
+ break;
+ default:
+ context->ReportError(context,
+ "Only float32 and uint8 are currently supported.");
+ return kTfLiteError;
+ }
+
+#undef TF_LITE_CONCATENATION
+
+ return kTfLiteOk;
+}
+
+#undef TF_LITE_MACRO_DISPATCH
+
+} // namespace concatenation
+
+TfLiteRegistration* Register_CONCATENATION_REF() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, concatenation::Prepare,
+ concatenation::Eval<concatenation::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, concatenation::Prepare,
+ concatenation::Eval<concatenation::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_CONCATENATION() {
+ // TODO(ahentz): It turns out the two versions of Concatenation are almost
+ // identical, so we should consider removing one.
+ return Register_CONCATENATION_GENERIC_OPT();
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/concatenation_test.cc b/tensorflow/contrib/lite/kernels/concatenation_test.cc
new file mode 100644
index 0000000000..94e5b2acdc
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/concatenation_test.cc
@@ -0,0 +1,162 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdarg>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseConcatenationOpModel : public SingleOpModel {
+ public:
+ // TODO(ahentz): Also test different activation types, axis, input
+ // dimensions.
+ BaseConcatenationOpModel(const TensorData& input_template, int axis,
+ int num_inputs) {
+ std::vector<std::vector<int>> all_input_shapes;
+ for (int i = 0; i < num_inputs; ++i) {
+ all_input_shapes.push_back(input_template.shape);
+ AddInput(input_template);
+ }
+ output_ = AddOutput({input_template.type, /*shape=*/{}, input_template.min,
+ input_template.max});
+ SetBuiltinOp(
+ BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions,
+ CreateConcatenationOptions(builder_, axis, ActivationFunctionType_NONE)
+ .Union());
+ BuildInterpreter(all_input_shapes);
+ }
+
+ protected:
+ int output_;
+};
+
+class ConcatenationOpModel : public BaseConcatenationOpModel {
+ public:
+ using BaseConcatenationOpModel::BaseConcatenationOpModel;
+ void SetInput(int index, std::initializer_list<float> data) {
+ PopulateTensor(index, data);
+ }
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+class QuantizedConcatenationOpModel : public BaseConcatenationOpModel {
+ public:
+ using BaseConcatenationOpModel::BaseConcatenationOpModel;
+ void SetInput(int index, std::initializer_list<float> data) {
+ QuantizeAndPopulate<uint8_t>(index, data);
+ }
+ std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
+};
+
+TEST(ConcatenationOpTest, ThreeDimensionalOneInput) {
+ ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/1,
+ /*num_inputs=*/1);
+ m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
+ m0.Invoke();
+ EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 3, 4, 7}));
+}
+
+TEST(ConcatenationOpTest, OneTrivialInput) {
+ ConcatenationOpModel m0({TensorType_FLOAT32, {1}}, /*axis=*/0,
+ /*num_inputs=*/1);
+ m0.SetInput(0, {5.0f});
+ m0.Invoke();
+ EXPECT_THAT(m0.GetOutput(), ::testing::ElementsAre(5));
+}
+
+TEST(ConcatenationOpTest, TwoDimensionalOneInput) {
+ ConcatenationOpModel m0({TensorType_FLOAT32, {2, 3}}, /*axis=*/0,
+ /*num_inputs=*/1);
+ m0.SetInput(0, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
+ m0.Invoke();
+ EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
+TEST(ConcatenationOpTest, TwoInputsTwoAxis) {
+ // We will concatenate two tensors along different dimensions.
+ auto tensor0 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+ auto tensor1 = {7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
+
+ ConcatenationOpModel m0({TensorType_FLOAT32, {2, 3}}, /*axis=*/0,
+ /*num_inputs=*/2);
+ m0.SetInput(0, tensor0);
+ m0.SetInput(1, tensor1);
+ m0.Invoke();
+ EXPECT_THAT(m0.GetOutput(),
+ ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
+
+ ConcatenationOpModel m1({TensorType_FLOAT32, {2, 3}}, /*axis=*/1,
+ /*num_inputs=*/2);
+ m1.SetInput(0, tensor0);
+ m1.SetInput(1, tensor1);
+ m1.Invoke();
+ EXPECT_THAT(m1.GetOutput(),
+ ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
+}
+
+TEST(ConcatenationOpTest, FourInputs) {
+ ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/2,
+ /*num_inputs=*/4);
+ m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
+ m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f});
+ m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f});
+ m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f});
+ m0.Invoke();
+ EXPECT_THAT(m0.GetOutput(),
+ ElementsAreArray({
+ 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, //
+ 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, //
+ }));
+}
+
+TEST(ConcatenationOpTest, FourInputsQuantized) {
+ QuantizedConcatenationOpModel m0({TensorType_UINT8, {2, 1, 2}, -12.7, 12.8},
+ /*axis=*/2,
+ /*num_inputs=*/4);
+
+ m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
+ m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f});
+ m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f});
+ m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f});
+ m0.Invoke();
+ EXPECT_THAT(m0.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({
+ 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, //
+ 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, //
+ })));
+ EXPECT_THAT(m0.GetOutput(), ElementsAreArray({
+ 137, 157, 138, 158, 139, 159, 140, 160, //
+ 167, 197, 168, 198, 169, 199, 170, 200, //
+ }));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
new file mode 100644
index 0000000000..c75c04baea
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -0,0 +1,425 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <unistd.h>
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/gemm_support.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+#include "tensorflow/contrib/lite/kernels/padding.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace conv {
+
+// This file has three implementation of Conv.
+enum KernelType {
+ kReference,
+ kGenericOptimized, // Neon-free
+ kNeonOptimized,
+};
+
+struct OpData {
+ // IDs are the arbitrary identifiers used by TF Lite to identify and access
+ // memory buffers.
+ int im2col_id;
+ int hwcn_weights_id;
+
+ TfLitePaddingValues padding;
+ // The scaling factor from input to output (aka the 'real multiplier') can
+ // be represented as a fixed point multipler plus a left shift.
+ int32_t output_multiplier;
+ int output_shift;
+ // The range of the fused activation layer. For example for kNone and
+ // uint8_t these would be 0 and 255.
+ int32_t output_activation_min;
+ int32_t output_activation_max;
+ // Indexes are the offset to the memory buffer in the array used to keep track
+ // of the allocated temporaries.
+ int32_t im2col_index;
+ int32_t hwcn_weights_index;
+ bool need_hwcn_weights;
+ bool have_weights_been_transposed;
+ bool need_im2col;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ // This is a builtin op, so we don't use the contents in 'buffer', if any.
+ // Instead, we allocate a new object to use as scratch space for im2col, and
+ // to carry information from Prepare() to Eval().
+ auto* data = new OpData;
+ context->AddTensors(context, 1, &data->im2col_id);
+ context->AddTensors(context, 1, &data->hwcn_weights_id);
+ gemm_support::IncrementUsageCounter(context);
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ gemm_support::DecrementUsageCounter(context);
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+// Naive implementation of transpose for floats. Could be optimized to be more
+// cache friendly, but for now it's a one-time cost on first run, and we would
+// prefer to remove the need to do this at all eventually.
+void TransposeFloatTensor(TfLiteTensor* input, TfLiteTensor* output) {
+ const int rows = output->dims->data[1];
+ const int cols = output->dims->data[0];
+ const float* input_data = GetTensorData<float>(input);
+ float* output_data = GetTensorData<float>(output);
+ for (int i = 0; i < rows; ++i) {
+ for (int j = 0; j < cols; ++j) {
+ const float in_value = input_data[i * cols + j];
+ output_data[j * rows + i] = in_value;
+ }
+ }
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ bool hasBias = node->inputs->size == 3;
+ // Check number of inputs/outputs
+ TF_LITE_ENSURE(context, hasBias || node->inputs->size == 2);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+ TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+ TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* filter = &context->tensors[node->inputs->data[1]];
+ // Check dimensionality of input, filter
+ TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
+ TF_LITE_ENSURE_EQ(context, filter->dims->size, 4);
+ // Check input channels matching filter
+ TF_LITE_ENSURE_EQ(context, input->dims->data[3], filter->dims->data[3]);
+
+ // Check types. (We assume that UINT8 refers to quantized tensors)
+ TfLiteType data_type = input->type;
+ TF_LITE_ENSURE(context,
+ data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8);
+ TF_LITE_ENSURE_EQ(context, output->type, data_type);
+ TF_LITE_ENSURE_EQ(context, filter->type, data_type);
+
+ TfLiteTensor* bias = nullptr;
+
+ // TODO(ahentz): At this point the optimized versions require 'bias'. We can
+ // either change that or document that convolution requires it.
+ TF_LITE_ENSURE(context, hasBias);
+
+ if (hasBias) {
+ bias = &context->tensors[node->inputs->data[2]];
+ if (data_type == kTfLiteUInt8) {
+ TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
+ } else {
+ TF_LITE_ENSURE_EQ(context, bias->type, data_type);
+ }
+ TF_LITE_ENSURE_EQ(context, bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, bias->dims->data[0], filter->dims->data[0]);
+ }
+
+ int channels_out = filter->dims->data[0];
+ int width = input->dims->data[2];
+ int height = input->dims->data[1];
+ int filter_width = filter->dims->data[2];
+ int filter_height = filter->dims->data[1];
+ int batches = input->dims->data[0];
+
+ // Matching GetWindowedOutputSize in TensorFlow.
+ auto padding = params->padding;
+ auto computeOutSize = [padding](int imageSize, int filterSize,
+ int stride) -> int {
+ return padding == kTfLitePaddingSame
+ ? (imageSize + stride - 1) / stride
+ : padding == kTfLitePaddingValid
+ ? (imageSize - filterSize + stride) / stride
+ : 0;
+ };
+
+ int outWidth = computeOutSize(width, filter_width, params->stride_width);
+ int outHeight = computeOutSize(height, filter_height, params->stride_height);
+
+ data->padding.height =
+ ComputePadding(params->stride_height, height, filter_height, outHeight);
+ data->padding.width =
+ ComputePadding(params->stride_width, width, filter_width, outWidth);
+
+ TF_LITE_ENSURE(context, hasBias);
+
+ // Note that quantized inference requires that all tensors have their
+ // parameters set. This is usually done during quantized training.
+ if (data_type != kTfLiteFloat32) {
+ double real_multiplier = 0.0;
+ TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
+ context, input, filter, bias, output, &real_multiplier));
+ QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier,
+ &data->output_shift);
+ CalculateActivationRangeUint8(params->activation, output,
+ &data->output_activation_min,
+ &data->output_activation_max);
+ }
+
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
+ output_size->data[0] = batches;
+ output_size->data[1] = outHeight;
+ output_size->data[2] = outWidth;
+ output_size->data[3] = channels_out;
+ auto output_status = context->ResizeTensor(context, output, output_size);
+
+ if (output_status != kTfLiteOk) return output_status;
+
+ // We don't always need to allocate im2col. It is only used in some versions
+ // of the optimized Conv. This test just mimics something that happens inside
+ // optimized_ops.h, in order to avoid a DCHECK(!im2col_data).
+ data->need_im2col =
+ (params->stride_width != 1 || params->stride_height != 1 ||
+ filter_width != 1 || filter_height != 1);
+ // If we're using the optimized multithreaded EigenTensor implementation of
+ // convolution, it expects the filter weights to be transposed compared to
+ // the normal TF Lite buffer format. Typical TF Lite weights are
+ // [filter_count, filter_height, filter_width, input_depth], but for the float
+ // implementation we need them as [filter_height, filter_width, input_depth,
+ // filter_count]. We get to that format by transposing, and create a temporary
+ // buffer to store the results.
+ // This path is only used for float processing, so only create the buffer if
+ // we're running with that data type.
+ data->need_hwcn_weights = (data_type == kTfLiteFloat32);
+
+ int temporaries_count = 0;
+ if (data->need_im2col) {
+ data->im2col_index = temporaries_count;
+ ++temporaries_count;
+ }
+ if (data->need_hwcn_weights) {
+ data->hwcn_weights_index = temporaries_count;
+ ++temporaries_count;
+ }
+
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(temporaries_count);
+
+ if (data->need_im2col) {
+ node->temporaries->data[data->im2col_index] = data->im2col_id;
+
+ TfLiteIntArray* im2col_size = TfLiteIntArrayCreate(4);
+
+ int input_depth = input->dims->data[3];
+ im2col_size->data[0] = output_size->data[0];
+ im2col_size->data[1] = output_size->data[1];
+ im2col_size->data[2] = output_size->data[2];
+ im2col_size->data[3] = input_depth * filter_height * filter_width;
+
+ TfLiteTensor* im2col =
+ &context->tensors[node->temporaries->data[data->im2col_index]];
+ im2col->type = data_type;
+ im2col->allocation_type = kTfLiteArenaRw;
+ auto im2col_status = context->ResizeTensor(context, im2col, im2col_size);
+ if (im2col_status != kTfLiteOk) return im2col_status;
+ }
+
+ if (data->need_hwcn_weights) {
+ node->temporaries->data[data->hwcn_weights_index] = data->hwcn_weights_id;
+ TfLiteIntArray* hwcn_weights_size = TfLiteIntArrayCreate(2);
+
+ // Because we're treating the filter weights as a matrix when we do the
+ // transpose, we allocate the buffer with a two-dimensional shape, where one
+ // dimension is the number of elements in each filter, and the second is the
+ // total number of filters.
+ int input_depth = input->dims->data[3];
+ hwcn_weights_size->data[0] = (filter_height * filter_width * input_depth);
+ hwcn_weights_size->data[1] = channels_out;
+
+ TfLiteTensor* hwcn_weights =
+ &context->tensors[node->temporaries->data[data->hwcn_weights_index]];
+ hwcn_weights->type = data_type;
+ hwcn_weights->allocation_type = kTfLiteDynamic;
+ // Make sure we release any previous allocations before we reallocate.
+ // TODO(petewarden): Persistent arenas would be a better fit for this, but
+ // they aren't fully implemented yet.
+ if (hwcn_weights->data.raw) {
+ free(hwcn_weights->data.raw);
+ hwcn_weights->data.raw = nullptr;
+ }
+ auto hwcn_weights_status =
+ context->ResizeTensor(context, hwcn_weights, hwcn_weights_size);
+ if (hwcn_weights_status != kTfLiteOk) return hwcn_weights_status;
+ hwcn_weights->data.raw = static_cast<char*>(malloc(hwcn_weights->bytes));
+
+ // TODO(petewarden): If Resize() is called when the size hasn't actually
+ // changed, this will do extra redundant work.
+ data->have_weights_been_transposed = false;
+ }
+
+ return kTfLiteOk;
+}
+
+template <KernelType kernel_type>
+void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteConvParams* params, OpData* data, TfLiteTensor* input,
+ TfLiteTensor* filter, TfLiteTensor* bias,
+ TfLiteTensor* im2col, TfLiteTensor* hwcn_weights,
+ TfLiteTensor* output) {
+ gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);
+
+ auto input_offset = -input->params.zero_point;
+ auto filter_offset = -filter->params.zero_point;
+ auto output_offset = output->params.zero_point;
+
+ if (kernel_type == kReference) {
+ reference_ops::Conv(
+ GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
+ GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
+ GetTensorData<int32_t>(bias), GetTensorDims(bias), params->stride_width,
+ params->stride_height, data->padding.width, data->padding.height,
+ output_offset, data->output_multiplier, data->output_shift,
+ data->output_activation_min, data->output_activation_max,
+ GetTensorData<uint8_t>(output), GetTensorDims(output),
+ GetTensorData<uint8_t>(im2col), GetTensorDims(im2col), gemm_context);
+ } else {
+ optimized_ops::Conv(
+ GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
+ GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
+ GetTensorData<int32_t>(bias), GetTensorDims(bias), params->stride_width,
+ params->stride_height, data->padding.width, data->padding.height,
+ output_offset, data->output_multiplier, data->output_shift,
+ data->output_activation_min, data->output_activation_max,
+ GetTensorData<uint8_t>(output), GetTensorDims(output),
+ GetTensorData<uint8_t>(im2col), GetTensorDims(im2col), gemm_context);
+ }
+}
+
+template <KernelType kernel_type>
+void EvalFloat(TfLiteContext* context, TfLiteNode* node,
+ TfLiteConvParams* params, OpData* data, TfLiteTensor* input,
+ TfLiteTensor* filter, TfLiteTensor* bias, TfLiteTensor* im2col,
+ TfLiteTensor* hwcn_weights, TfLiteTensor* output) {
+ float output_activation_min, output_activation_max;
+ CalculateActivationRangeFloat(params->activation, &output_activation_min,
+ &output_activation_max);
+
+ const float* filter_data;
+ if (data->need_hwcn_weights) {
+ filter_data = GetTensorData<float>(hwcn_weights);
+ } else {
+ filter_data = GetTensorData<float>(filter);
+ }
+
+ if (kernel_type == kReference) {
+ reference_ops::Conv(
+ GetTensorData<float>(input), GetTensorDims(input), filter_data,
+ GetTensorDims(filter), GetTensorData<float>(bias), GetTensorDims(bias),
+ params->stride_width, params->stride_height, data->padding.width,
+ data->padding.height, output_activation_min, output_activation_max,
+ GetTensorData<float>(output), GetTensorDims(output),
+ GetTensorData<float>(im2col), GetTensorDims(im2col));
+ } else {
+ multithreaded_ops::Conv(
+ GetTensorData<float>(input), GetTensorDims(input), filter_data,
+ GetTensorDims(filter), GetTensorData<float>(bias), GetTensorDims(bias),
+ params->stride_width, params->stride_height, data->padding.width,
+ data->padding.height, params->padding, output_activation_min,
+ output_activation_max, GetTensorData<float>(output),
+ GetTensorDims(output), GetTensorData<float>(im2col),
+ GetTensorDims(im2col));
+ }
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+ TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* filter = &context->tensors[node->inputs->data[1]];
+ bool hasBias = node->inputs->size == 3;
+ TfLiteTensor* bias =
+ hasBias ? &context->tensors[node->inputs->data[2]] : nullptr;
+ TfLiteTensor* im2col =
+ data->need_im2col
+ ? &context->tensors[node->temporaries->data[data->im2col_index]]
+ : nullptr;
+ TfLiteTensor* hwcn_weights =
+ data->need_hwcn_weights
+ ? &context->tensors[node->temporaries->data[data->hwcn_weights_index]]
+ : nullptr;
+
+ if (data->need_hwcn_weights && !data->have_weights_been_transposed) {
+ TransposeFloatTensor(filter, hwcn_weights);
+ data->have_weights_been_transposed = true;
+ }
+
+ // TODO(aselle): Consider whether float conv and quantized conv should be
+ // separate ops to avoid dispatch overhead here.
+ switch (input->type) { // Already know in/outtypes are same.
+ case kTfLiteFloat32:
+ EvalFloat<kernel_type>(context, node, params, data, input, filter, bias,
+ im2col, hwcn_weights, output);
+ break;
+ case kTfLiteUInt8:
+ EvalQuantized<kernel_type>(context, node, params, data, input, filter,
+ bias, im2col, hwcn_weights, output);
+ break;
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace conv
+
+TfLiteRegistration* Register_CONVOLUTION_REF() {
+ static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare,
+ conv::Eval<conv::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_CONVOLUTION_GENERIC_OPT() {
+ static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare,
+ conv::Eval<conv::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_CONVOLUTION_NEON_OPT() {
+ static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare,
+ conv::Eval<conv::kNeonOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_CONV_2D() {
+#ifdef USE_NEON
+ return Register_CONVOLUTION_NEON_OPT();
+#else
+ return Register_CONVOLUTION_GENERIC_OPT();
+#endif
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc
new file mode 100644
index 0000000000..18d7a31d59
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/conv_test.cc
@@ -0,0 +1,440 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdarg>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseConvolutionOpModel : public SingleOpModel {
+ public:
+ // TODO(ahentz): Also test different activation types, bias, padding types,
+ // stride values.
+ BaseConvolutionOpModel(
+ const TensorData& input, const TensorData& filter,
+ const TensorData& output, int stride_width = 2, int stride_height = 2,
+ enum Padding padding = Padding_VALID,
+ enum ActivationFunctionType activation = ActivationFunctionType_NONE) {
+ input_ = AddInput(input);
+ filter_ = AddInput(filter);
+
+ int bias_size = GetShape(filter_)[0];
+ if (input.type == TensorType_FLOAT32) {
+ bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
+ } else {
+ // This is a quantized version. The scale of 'bias' depends on the scales
+ // of input and filter. Supposedly this is correctly set during quantized
+ // training.
+ auto bias_scale = GetScale(input_) * GetScale(filter_);
+ TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale};
+ bias_ = AddInput(bias);
+ }
+
+ output_ = AddOutput(output);
+ if (input.type != TensorType_FLOAT32) {
+ // The following is required by quantized inference. It is the unittest's
+ // responsibility to make sure the output scale falls into the correct
+ // range.
+ CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_));
+ }
+
+ SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions,
+ CreateConv2DOptions(builder_, padding, stride_width,
+ stride_height, activation)
+ .Union());
+
+ BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
+ }
+
+ protected:
+ int input_;
+ int filter_;
+ int bias_;
+ int output_;
+};
+
+class ConvolutionOpModel : public BaseConvolutionOpModel {
+ public:
+ using BaseConvolutionOpModel::BaseConvolutionOpModel;
+
+ void SetFilter(std::initializer_list<float> f) { PopulateTensor(filter_, f); }
+
+ void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+TEST(ConvolutionOpTest, SimpleTestFloat32) {
+ ConvolutionOpModel m({TensorType_FLOAT32, {2, 2, 4, 1}},
+ {TensorType_FLOAT32, {3, 2, 2, 1}},
+ {TensorType_FLOAT32, {}});
+
+ m.SetInput({
+ // First batch
+ 1, 1, 1, 1, // row = 1
+ 2, 2, 2, 2, // row = 2
+ // Second batch
+ 1, 2, 3, 4, // row = 1
+ 1, 2, 3, 4, // row = 2
+ });
+ m.SetFilter({
+ 1, 2, 3, 4, // first 2x2 filter
+ -1, 1, -1, 1, // second 2x2 filter
+ -1, -1, 1, 1, // third 2x2 filter
+ });
+ m.SetBias({1, 2, 3});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 18, 2, 5, // first batch, left
+ 18, 2, 5, // first batch, right
+ 17, 4, 3, // second batch, left
+ 37, 4, 3, // second batch, right
+ }));
+}
+
+TEST(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) {
+ ConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 6, 1}},
+ {TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}},
+ /*stride_width=*/3, /*stride_height=*/1);
+ m.SetInput({
+ 3, 2, 1, -1, -2, -3, //
+ 4, 3, 2, -2, -3, -4, //
+ 5, 4, 3, -3, -4, -5, //
+ });
+ m.SetFilter({
+ 1, 2, //
+ 3, 4, //
+ });
+ m.SetBias({-1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 30, -24, //
+ 40, -34, //
+ }));
+}
+
+TEST(ConvolutionOpTest, HandCalculatedFloat32) {
+ const int depth = 1;
+ const int image_width = 4;
+ const int image_height = 3;
+ const int image_batch_count = 1;
+ const int filter_size = 3;
+ const int filter_count = 1;
+ const int stride_width = 1;
+ const int stride_height = 1;
+ const Padding padding = Padding_SAME;
+ ConvolutionOpModel m(
+ {TensorType_FLOAT32,
+ {image_batch_count, image_height, image_width, depth}},
+ {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
+ {TensorType_FLOAT32, {}}, stride_width, stride_height, padding);
+
+ // The image matrix is:
+ // | 1 | 2 | 3 | 4 |
+ // | 5 | 6 | 7 | 8 |
+ // | 9 | 10 | 11 | 12 |
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ // The filter matrix is:
+ // | 1 | 4 | 7 |
+ // | 2 | 5 | 8 |
+ // | 3 | 6 | 9 |
+ m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9});
+ // No bias for this test.
+ m.SetBias({0});
+
+ m.Invoke();
+ // We're sliding the 3x3 filter across the 3x4 image, with accesses outside
+ // the input set to zero because we're using the 'SAME' padding mode.
+ // The calculations behind the expected output are:
+ // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)=105
+ // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)=150
+ // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)=183
+ // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)=95
+ // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)=235
+ // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312
+ // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357
+ // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)=178
+ // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)=187
+ // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)=234
+ // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)=261
+ // (1*7)+(4*11)+(7*0)+(2*8)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)=121
+ // This means we should end up with this matrix:
+ // | 105 | 150 | 183 | 95 |
+ // | 235 | 312 | 357 | 178 |
+ // | 187 | 234 | 261 | 121 |
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({105, 150, 183, 95, 235, 312, 357,
+ 178, 187, 234, 261, 121}));
+}
+
+TEST(ConvolutionOpTest, HandCalculatedWithBiasFloat32) {
+ const int depth = 1;
+ const int image_width = 4;
+ const int image_height = 3;
+ const int image_batch_count = 1;
+ const int filter_size = 3;
+ const int filter_count = 1;
+ const int stride_width = 1;
+ const int stride_height = 1;
+ const Padding padding = Padding_SAME;
+ ConvolutionOpModel m(
+ {TensorType_FLOAT32,
+ {image_batch_count, image_height, image_width, depth}},
+ {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
+ {TensorType_FLOAT32, {}}, stride_width, stride_height, padding);
+
+ // The image matrix is:
+ // | 1 | 2 | 3 | 4 |
+ // | 5 | 6 | 7 | 8 |
+ // | 9 | 10 | 11 | 12 |
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ // The filter matrix is:
+ // | 1 | 4 | 7 |
+ // | 2 | 5 | 8 |
+ // | 3 | 6 | 9 |
+ m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9});
+ // Bias is | 10 |.
+ m.SetBias({10});
+
+ m.Invoke();
+ // We're sliding the 3x3 filter across the 3x4 image, with accesses outside
+ // the input set to zero because we're using the 'SAME' padding mode.
+ // The calculations behind the expected output are:
+ // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)+10=115
+ // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)+10=160
+ // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)+10=193
+ // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)+10=105
+ // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)+10=245
+ // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)+10=322
+ // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)+10=367
+ // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)+10=188
+ // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)+10=197
+ // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)+10=244
+ // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)+10=271
+ // (1*7)+(4*11)+(7*0)+(2*8)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)+10=131
+ // This means we should end up with this matrix:
+ // | 115 | 160 | 193 | 105 |
+ // | 245 | 322 | 367 | 188 |
+ // | 197 | 244 | 271 | 131 |
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({115, 160, 193, 105, 245, 322,
+ 367, 188, 197, 244, 271, 131}));
+}
+
+TEST(ConvolutionOpTest, HandCalculatedWithReluFloat32) {
+ const int depth = 1;
+ const int image_width = 4;
+ const int image_height = 3;
+ const int image_batch_count = 1;
+ const int filter_size = 3;
+ const int filter_count = 1;
+ const int stride_width = 1;
+ const int stride_height = 1;
+ const Padding padding = Padding_SAME;
+ ConvolutionOpModel m(
+ {TensorType_FLOAT32,
+ {image_batch_count, image_height, image_width, depth}},
+ {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
+ {TensorType_FLOAT32, {}}, stride_width, stride_height, padding,
+ ActivationFunctionType_RELU);
+
+ // The image matrix is:
+ // | 1 | 2 | 3 | 4 |
+ // | 5 | 6 | 7 | 8 |
+ // | 9 | 10 | 11 | 12 |
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ // The filter matrix is:
+ // | 1 | 4 | 7 |
+ // | 2 | 5 | 8 |
+ // | 3 | 6 | 9 |
+ m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9});
+ // Bias is | -200 |.
+ m.SetBias({-200});
+
+ m.Invoke();
+ // We're sliding the 3x3 filter across the 3x4 image, with accesses outside
+ // the input set to zero because we're using the 'SAME' padding mode.
+ // The calculations behind the expected output are:
+ // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)-200=-95
+ // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)-200=-50
+ // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)-200=-17
+ // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)-200=-105
+ // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)-200=35
+ // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)-200=112
+ // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)-200=157
+ // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)-200=-22
+ // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)-200=-13
+ // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)-200=34
+ // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)-200=61
+ // (1*7)+(4*11)+(7*0)+(2*8)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)-200=-79
+ // All negative values are gated to zero by the Relu activation function.
+ // This means we should end up with this matrix:
+ // | 0 | 0 | 0 | 0 |
+ // | 35 | 112 | 157 | 0 |
+ // | 0 | 34 | 61 | 0 |
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({0, 0, 0, 0, 35, 112, 157, 0, 0, 34, 61, 0}));
+}
+
+TEST(ConvolutionOpTest, HandCalculatedValidFloat32) {
+ const int depth = 1;
+ const int image_width = 4;
+ const int image_height = 3;
+ const int image_batch_count = 1;
+ const int filter_size = 3;
+ const int filter_count = 1;
+ const int stride_width = 1;
+ const int stride_height = 1;
+ const Padding padding = Padding_VALID;
+ ConvolutionOpModel m(
+ {TensorType_FLOAT32,
+ {image_batch_count, image_height, image_width, depth}},
+ {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
+ {TensorType_FLOAT32, {}}, stride_width, stride_height, padding);
+
+ // The image matrix is:
+ // | 1 | 2 | 3 | 4 |
+ // | 5 | 6 | 7 | 8 |
+ // | 9 | 10 | 11 | 12 |
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ // The filter matrix is:
+ // | 1 | 4 | 7 |
+ // | 2 | 5 | 8 |
+ // | 3 | 6 | 9 |
+ m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9});
+ // No bias for this test.
+ m.SetBias({0});
+
+ m.Invoke();
+ // We're sliding the 3x3 filter across the 3x4 image, with no accesses outside
+ // the input because we're using the 'VALID' padding mode, giving a 2x1
+ // output.
+ // The calculations behind the expected output are:
+ // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312
+ // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357
+ // This means we should end up with this matrix:
+ // | 312 | 357 |
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({312, 357}));
+}
+
+class QuantizedConvolutionOpModel : public BaseConvolutionOpModel {
+ public:
+ using BaseConvolutionOpModel::BaseConvolutionOpModel;
+
+ void SetInput(std::initializer_list<float> data) {
+ QuantizeAndPopulate<uint8_t>(input_, data);
+ }
+
+ void SetFilter(std::initializer_list<float> data) {
+ QuantizeAndPopulate<uint8_t>(filter_, data);
+ }
+
+ void SetBias(std::initializer_list<float> data) {
+ QuantizeAndPopulate<int32_t>(bias_, data);
+ }
+
+ std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
+};
+
+// In this tests we set the input and output scales so that the results
+// match exactly the 'non-quantized' version.
+TEST(ConvolutionOpTest, SimpleTestQuantized) {
+ QuantizedConvolutionOpModel m({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64},
+ {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64},
+ {TensorType_UINT8, {}, -127, 128});
+ m.SetInput({
+ // First batch
+ 1, 1, 1, 1, // row = 1
+ 2, 2, 2, 2, // row = 2
+ // Second batch
+ 1, 2, 3, 4, // row = 1
+ 1, 2, 3, 4, // row = 2
+ });
+ m.SetFilter({
+ 1, 2, 3, 4, // first 2x2 filter
+ -1, 1, -1, 1, // second 2x2 filter
+ -1, -1, 1, 1, // third 2x2 filter
+ });
+ m.SetBias({1, 2, 3});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 18, 2, 5, // first batch, left
+ 18, 2, 5, // first batch, right
+ 17, 4, 3, // second batch, left
+ 37, 4, 3, // second batch, right
+ },
+ 1e-5)));
+ // For good measure, let's also verify the quantized values:
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 145, 129, 132, //
+ 145, 129, 132, //
+ 144, 131, 130, //
+ 164, 131, 130, //
+ }));
+}
+
+TEST(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) {
+ QuantizedConvolutionOpModel m({TensorType_UINT8, {1, 3, 6, 1}, -63.5, 64},
+ {TensorType_UINT8, {1, 2, 2, 1}, -63.5, 64},
+ {TensorType_UINT8, {}, -127, 128},
+ /*stride_width=*/3, /*stride_height=*/1);
+ m.SetInput({
+ 3, 2, 1, -1, -2, -3, //
+ 4, 3, 2, -2, -3, -4, //
+ 5, 4, 3, -3, -4, -5, //
+ });
+ m.SetFilter({
+ 1, 2, //
+ 3, 4, //
+ });
+ m.SetBias({-1});
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({
+ 30, -24, //
+ 40, -34, //
+ })));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 157, 103, //
+ 167, 93, //
+ }));
+}
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
new file mode 100644
index 0000000000..15dbfe08c8
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
@@ -0,0 +1,289 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+#include "tensorflow/contrib/lite/kernels/padding.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace depthwise_conv {
+
+constexpr int kInputTensor = 0;
+constexpr int kFilterTensor = 1;
+constexpr int kBiasTensor = 2;
+constexpr int kOutputTensor = 0;
+
+// This file has three implementation of DepthwiseConv.
+enum KernelType {
+ kReference,
+ kGenericOptimized, // Neon-free
+ kNeonOptimized,
+};
+
+struct OpData {
+ TfLitePaddingValues padding;
+ // The scaling factor from input to output (aka the 'real multiplier') can
+ // be represented as a fixed point multipler plus a left shift.
+ int32_t output_multiplier;
+ int output_shift;
+ // The range of the fused activation layer. For example for kNone and
+ // uint8_t these would be 0 and 255.
+ int32_t output_activation_min;
+ int32_t output_activation_max;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ // This is a builtin op, so we don't use the contents in 'buffer', if any.
+ // Instead, we allocate a new object to carry information from Prepare() to
+ // Eval().
+ return new OpData;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ // TODO(ahentz): use could use GetOptionalInputTensor() here, but we need to
+ // decide whether we are OK with optional tensors being completely absent, as
+ // opposed to having -1 as their index.
+ bool hasBias = NumInputs(node) == 3;
+
+ TF_LITE_ENSURE(context, hasBias || NumInputs(node) == 2);
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
+ TfLiteTensor* bias = nullptr;
+
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 4);
+
+ // The parameter 'depth_multiplier' is redundant, so we check here to make
+ // sure it is consistent with the given dimensions.
+ TF_LITE_ENSURE_EQ(context,
+ params->depth_multiplier * SizeOfDimension(input, 3),
+ SizeOfDimension(filter, 3));
+
+ const TfLiteType data_type = input->type;
+ TF_LITE_ENSURE(context,
+ data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8);
+ TF_LITE_ENSURE_EQ(context, output->type, data_type);
+ TF_LITE_ENSURE_EQ(context, filter->type, data_type);
+
+ if (hasBias) {
+ bias = GetInput(context, node, kBiasTensor);
+ if (data_type == kTfLiteUInt8) {
+ TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
+ } else {
+ TF_LITE_ENSURE_EQ(context, bias->type, data_type);
+ }
+ TF_LITE_ENSURE_EQ(context, NumDimensions(bias), 1);
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(filter, 3),
+ SizeOfDimension(bias, 0));
+ }
+
+ int channels_out = SizeOfDimension(filter, 3);
+ int width = SizeOfDimension(input, 2);
+ int height = SizeOfDimension(input, 1);
+ int filter_width = SizeOfDimension(filter, 2);
+ int filter_height = SizeOfDimension(filter, 1);
+ int batches = SizeOfDimension(input, 0);
+
+ // Matching GetWindowedOutputSize in TensorFlow.
+ auto padding = params->padding;
+ auto compute_out_size = [padding](int imageSize, int filterSize,
+ int stride) -> int {
+ return padding == kTfLitePaddingSame
+ ? (imageSize + stride - 1) / stride
+ : padding == kTfLitePaddingValid
+ ? (imageSize - filterSize + stride) / stride
+ : 0;
+ };
+
+ int out_width = compute_out_size(width, filter_width, params->stride_width);
+ int out_height =
+ compute_out_size(height, filter_height, params->stride_height);
+
+ data->padding.height =
+ ComputePadding(params->stride_height, height, filter_height, out_height);
+ data->padding.width =
+ ComputePadding(params->stride_width, width, filter_width, out_width);
+
+ // Note that quantized inference requires that all tensors have their
+ // parameters set. This is usually done during quantized training.
+ if (data_type != kTfLiteFloat32) {
+ double real_multiplier = 0.0;
+ TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
+ context, input, filter, bias, output, &real_multiplier));
+ QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier,
+ &data->output_shift);
+ CalculateActivationRangeUint8(params->activation, output,
+ &data->output_activation_min,
+ &data->output_activation_max);
+ }
+
+ TfLiteIntArray* outputSize = TfLiteIntArrayCreate(4);
+ outputSize->data[0] = batches;
+ outputSize->data[1] = out_height;
+ outputSize->data[2] = out_width;
+ outputSize->data[3] = channels_out;
+ return context->ResizeTensor(context, output, outputSize);
+}
+
+template <KernelType kernel_type>
+void EvalFloat(TfLiteContext* context, TfLiteNode* node,
+ TfLiteDepthwiseConvParams* params, OpData* data,
+ TfLiteTensor* input, TfLiteTensor* filter, TfLiteTensor* bias,
+ TfLiteTensor* output) {
+ float output_activation_min, output_activation_max;
+ CalculateActivationRangeFloat(params->activation, &output_activation_min,
+ &output_activation_max);
+
+ void (*depthwise_conv)(const float*, const Dims<4>&, const float*,
+ const Dims<4>&, const float*, const Dims<4>&, int, int,
+ int, int, int, float, float, float*, const Dims<4>&);
+ if (kernel_type == kReference) {
+ depthwise_conv = &reference_ops::DepthwiseConv;
+ } else {
+ depthwise_conv = &optimized_ops::DepthwiseConv;
+ }
+
+ depthwise_conv(
+ GetTensorData<float>(input), GetTensorDims(input),
+ GetTensorData<float>(filter), GetTensorDims(filter),
+ GetTensorData<float>(bias), GetTensorDims(bias), params->stride_width,
+ params->stride_height, data->padding.width, data->padding.height,
+ params->depth_multiplier, output_activation_min, output_activation_max,
+ GetTensorData<float>(output), GetTensorDims(output));
+}
+
+template <KernelType kernel_type>
+void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteDepthwiseConvParams* params, OpData* data,
+ TfLiteTensor* input, TfLiteTensor* filter,
+ TfLiteTensor* bias, TfLiteTensor* output) {
+ auto input_offset = -input->params.zero_point;
+ auto filter_offset = -filter->params.zero_point;
+ auto output_offset = output->params.zero_point;
+
+ void (*depthwise_conv)(const uint8*, const Dims<4>&, int32, const uint8*,
+ const Dims<4>&, int32, const int32*, const Dims<4>&,
+ int, int, int, int, int, int32, int32, int, int32,
+ int32, uint8*, const Dims<4>&);
+ if (kernel_type == kReference) {
+ depthwise_conv = &reference_ops::DepthwiseConv;
+ } else {
+ depthwise_conv = &optimized_ops::DepthwiseConv;
+ }
+
+ depthwise_conv(
+ GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
+ GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
+ GetTensorData<int32_t>(bias), GetTensorDims(bias), params->stride_width,
+ params->stride_height, data->padding.width, data->padding.height,
+ params->depth_multiplier, output_offset, data->output_multiplier,
+ data->output_shift, data->output_activation_min,
+ data->output_activation_max, GetTensorData<uint8_t>(output),
+ GetTensorDims(output));
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
+ TfLiteTensor* bias =
+ (NumInputs(node) == 3) ? GetInput(context, node, kBiasTensor) : nullptr;
+
+ // TODO(aselle): Consider whether float conv and quantized conv should be
+ // separate ops to avoid dispatch overhead here.
+ switch (input->type) { // Already know in/out types are same.
+ case kTfLiteFloat32:
+ EvalFloat<kernel_type>(context, node, params, data, input, filter, bias,
+ output);
+ break;
+ case kTfLiteUInt8:
+ EvalQuantized<kernel_type>(context, node, params, data, input, filter,
+ bias, output);
+ break;
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace depthwise_conv
+
+TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_REF() {
+ static TfLiteRegistration r = {
+ depthwise_conv::Init, depthwise_conv::Free, depthwise_conv::Prepare,
+ depthwise_conv::Eval<depthwise_conv::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT() {
+ static TfLiteRegistration r = {
+ depthwise_conv::Init, depthwise_conv::Free, depthwise_conv::Prepare,
+ depthwise_conv::Eval<depthwise_conv::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_NEON_OPT() {
+ static TfLiteRegistration r = {
+ depthwise_conv::Init, depthwise_conv::Free, depthwise_conv::Prepare,
+ depthwise_conv::Eval<depthwise_conv::kNeonOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_DEPTHWISE_CONV_2D() {
+#ifdef USE_NEON
+ return Register_DEPTHWISE_CONVOLUTION_NEON_OPT();
+#else
+ return Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT();
+#endif
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
new file mode 100644
index 0000000000..39227b2811
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
@@ -0,0 +1,186 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdarg>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseDepthwiseConvolutionOpModel : public SingleOpModel {
+ public:
+ // TODO(ahentz): Also test different activation types, bias, padding types,
+ // stride values.
+ BaseDepthwiseConvolutionOpModel(const TensorData& input,
+ const TensorData& filter,
+ const TensorData& output) {
+ input_ = AddInput(input);
+ filter_ = AddInput(filter);
+
+ int bias_size = GetShape(filter_)[3];
+ if (input.type == TensorType_FLOAT32) {
+ bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
+ } else {
+ // This is a quantized version. The scale of 'bias' depends on the scales
+ // of input and filter. Supposedly this is correctly set during quantized
+ // training.
+ auto bias_scale = GetScale(input_) * GetScale(filter_);
+ TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale};
+ bias_ = AddInput(bias);
+ }
+
+ output_ = AddOutput(output);
+ if (input.type != TensorType_FLOAT32) {
+ // The following is required by quantized inference. It is the unittest's
+ // responsibility to make sure the output scale falls into the correct
+ // range.
+ CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_));
+ }
+
+ int input_depth = GetShape(input_)[3];
+ int output_depth = GetShape(filter_)[3];
+ int depth_mul = output_depth / input_depth;
+
+ SetBuiltinOp(
+ BuiltinOperator_DEPTHWISE_CONV_2D,
+ BuiltinOptions_DepthwiseConv2DOptions,
+ CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul,
+ ActivationFunctionType_NONE)
+ .Union());
+
+ BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
+ }
+
+ protected:
+ int input_;
+ int filter_;
+ int bias_;
+ int output_;
+};
+
+class DepthwiseConvolutionOpModel : public BaseDepthwiseConvolutionOpModel {
+ public:
+ using BaseDepthwiseConvolutionOpModel::BaseDepthwiseConvolutionOpModel;
+
+ void SetFilter(std::initializer_list<float> f) { PopulateTensor(filter_, f); }
+
+ void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+TEST(DepthwiseConvolutionOpTest, SimpleTest) {
+ DepthwiseConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 2, 2}},
+ {TensorType_FLOAT32, {1, 2, 2, 4}},
+ {TensorType_FLOAT32, {}});
+
+ m.SetInput({
+ 1, 2, 7, 8, // column 1
+ 3, 4, 9, 10, // column 2
+ 5, 6, 11, 12, // column 3
+ });
+ m.SetFilter({
+ 1, 2, 3, 4, //
+ -9, 10, -11, 12, //
+ 5, 6, 7, 8, //
+ 13, -14, 15, -16, //
+ });
+ m.SetBias({1, 2, 3, 4});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 71, -34, 99, -20, //
+ 91, -26, 127, -4, //
+ }));
+}
+
+class QuantizedDepthwiseConvolutionOpModel
+ : public BaseDepthwiseConvolutionOpModel {
+ public:
+ using BaseDepthwiseConvolutionOpModel::BaseDepthwiseConvolutionOpModel;
+
+ void SetInput(std::initializer_list<float> data) {
+ QuantizeAndPopulate<uint8_t>(input_, data);
+ }
+
+ void SetFilter(std::initializer_list<float> data) {
+ QuantizeAndPopulate<uint8_t>(filter_, data);
+ }
+
+ void SetBias(std::initializer_list<float> data) {
+ QuantizeAndPopulate<int32_t>(bias_, data);
+ }
+
+ std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
+};
+
+// In this test we set the input and output scales so that the results match
+// exactly the 'non-quantized' version.
+TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) {
+ QuantizedDepthwiseConvolutionOpModel m(
+ {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
+ {TensorType_UINT8, {1, 2, 2, 4}, -63.5, 64},
+ {TensorType_UINT8, {}, -127, 128});
+
+ m.SetInput({
+ 1, 2, 7, 8, // column 1
+ 3, 4, 9, 10, // column 2
+ 5, 6, 11, 12, // column 3
+ });
+ m.SetFilter({
+ 1, 2, 3, 4, //
+ -9, 10, -11, 12, //
+ 5, 6, 7, 8, //
+ 13, -14, 15, -16, //
+ });
+ m.SetBias({1, 2, 3, 4});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
+ {
+ 71, -34, 99, -20, //
+ 91, -26, 127, -4, //
+ },
+ 1e-5)));
+ // For good measure, let's also verify the quantized values:
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 198, 93, 226, 107, //
+ 218, 101, 254, 123, //
+ }));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
new file mode 100644
index 0000000000..4e8cb396d4
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
@@ -0,0 +1,104 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Ops that looks up items from matrix.
+//
+// Input:
+// Tensor[0]: Row number to lookup, dim.size == 1, int32
+// Tensor[1]: 2-dimensional matrix of multi-dimensional items
+// dim.size >= 2, any data type.
+// first dimension is row, second dimension is column.
+//
+// Output:
+// Output.dim[0] == Tensor[0].dim[0], num of lookups
+// Output.dim[1] == Tensor[1].dim[1], num of items per row
+// Each item in output is a raw bytes copy of corresponding item in input.
+// When indices are out of bound, the ops will not succeed.
+//
+
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace embedding_lookup {
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TfLiteTensor* lookup = GetInput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
+ TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
+
+ TfLiteTensor* value = GetInput(context, node, 1);
+ TF_LITE_ENSURE(context, NumDimensions(value) >= 2);
+
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value));
+
+ outputSize->data[0] = SizeOfDimension(lookup, 0);
+ outputSize->data[1] = SizeOfDimension(value, 1);
+ for (int i = 2; i < NumDimensions(value); i++) {
+ outputSize->data[i] = SizeOfDimension(value, i);
+ }
+ return context->ResizeTensor(context, output, outputSize);
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TfLiteTensor* lookup = GetInput(context, node, 0);
+ TfLiteTensor* value = GetInput(context, node, 1);
+
+ const int row_size = SizeOfDimension(value, 0);
+ const int row_bytes = value->bytes / row_size;
+
+ for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
+ int idx = lookup->data.i32[i];
+ if (idx >= row_size || idx < 0) {
+ context->ReportError(context, "Embedding Lookup: index out of bounds.");
+ return kTfLiteError;
+ } else {
+ memcpy(output->data.raw + i * row_bytes,
+ value->data.raw + idx * row_bytes, row_bytes);
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace embedding_lookup
+
+TfLiteRegistration* Register_EMBEDDING_LOOKUP() {
+ static TfLiteRegistration r = {nullptr, nullptr, embedding_lookup::Prepare,
+ embedding_lookup::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
new file mode 100644
index 0000000000..6c770e7f71
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
@@ -0,0 +1,248 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Op that looks up items from a sparse tensor in an embedding matrix.
+// The sparse lookup tensor is represented by three individual tensors: lookup,
+// indices, and dense_shape. The representation assume that the corresponding
+// dense tensor would satisfy:
+// * dense.shape = dense_shape
+// * dense[tuple(indices[i])] = lookup[i]
+//
+// By convention, indices should be sorted.
+//
+// Options:
+// combiner: The reduction op (SUM, MEAN, SQRTN).
+// * SUM computes the weighted sum of the embedding results.
+// * MEAN is the weighted sum divided by the total weight.
+// * SQRTN is the weighted sum divided by the square root of the sum of the
+// squares of the weights.
+//
+// Input:
+// Tensor[0]: Ids to lookup, dim.size == 1, int32.
+// Tensor[1]: Indices, int32.
+// Tensor[2]: Dense shape, int32.
+// Tensor[3]: Weights to use for aggregation, float.
+// Tensor[4]: Params, a matrix of multi-dimensional items,
+// dim.size >= 2, float.
+//
+// Output:
+// A (dense) tensor representing the combined embeddings for the sparse ids.
+// For each row in the sparse tensor represented by (lookup, indices, shape)
+// the op looks up the embeddings for all ids in that row, multiplies them by
+// the corresponding weight, and combines these embeddings as specified in the
+// last dimension.
+//
+// Output.dim = [l0, ... , ln-1, e1, ..., em]
+// Where dense_shape == [l0, ..., ln] and Tensor[4].dim == [e0, e1, ..., em]
+//
+// For instance, if params is a 10x20 matrix and ids, weights are:
+//
+// [0, 0]: id 1, weight 2.0
+// [0, 1]: id 3, weight 0.5
+// [1, 0]: id 0, weight 1.0
+// [2, 3]: id 1, weight 3.0
+//
+// with combiner=MEAN, then the output will be a (3, 20) tensor where:
+//
+// output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
+// output[1, :] = (params[0, :] * 1.0) / 1.0
+// output[2, :] = (params[1, :] * 3.0) / 3.0
+//
+// When indices are out of bound, the op will not succeed.
+
+#include <algorithm>
+#include <cmath>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+
+namespace {
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 5);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TfLiteTensor* ids = GetInput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(ids), 1);
+ TF_LITE_ENSURE_EQ(context, ids->type, kTfLiteInt32);
+
+ TfLiteTensor* indices = GetInput(context, node, 1);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(indices), 2);
+ TF_LITE_ENSURE_EQ(context, indices->type, kTfLiteInt32);
+
+ TfLiteTensor* shape = GetInput(context, node, 2);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(shape), 1);
+ TF_LITE_ENSURE_EQ(context, shape->type, kTfLiteInt32);
+
+ TfLiteTensor* weights = GetInput(context, node, 3);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 1);
+ TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteFloat32);
+
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0),
+ SizeOfDimension(ids, 0));
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0),
+ SizeOfDimension(weights, 0));
+
+ TfLiteTensor* value = GetInput(context, node, 4);
+ TF_LITE_ENSURE(context, NumDimensions(value) >= 2);
+
+ // Mark the output as a dynamic tensor.
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
+ output->allocation_type = kTfLiteDynamic;
+
+ return kTfLiteOk;
+}
+
+void FinalizeAggregation(TfLiteCombinerType combiner, int num_elements,
+ float current_total_weight,
+ float current_squares_weight, int embedding_size,
+ float* output) {
+ if (combiner != kTfLiteCombinerTypeSum && num_elements > 0) {
+ float multiplier = 1.0;
+ switch (combiner) {
+ case kTfLiteCombinerTypeMean:
+ multiplier = current_total_weight;
+ break;
+ case kTfLiteCombinerTypeSqrtn:
+ multiplier = std::sqrt(current_squares_weight);
+ break;
+ default:
+ break;
+ }
+ for (int k = 0; k < embedding_size; k++) {
+ output[k] /= multiplier;
+ }
+ }
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteEmbeddingLookupSparseParams*>(node->builtin_data);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TfLiteTensor* ids = GetInput(context, node, 0);
+ TfLiteTensor* indices = GetInput(context, node, 1);
+ TfLiteTensor* dense_shape = GetInput(context, node, 2);
+ TfLiteTensor* weights = GetInput(context, node, 3);
+ TfLiteTensor* value = GetInput(context, node, 4);
+
+ const int lookup_rank = SizeOfDimension(indices, 1);
+ const int embedding_rank = NumDimensions(value);
+ const int num_lookups = SizeOfDimension(ids, 0);
+ const int num_rows = SizeOfDimension(value, 0);
+
+ // The last dimension gets replaced by the embedding.
+ const int output_rank = (lookup_rank - 1) + (embedding_rank - 1);
+
+ // Make sure that the actual dense shape of the sparse tensor represented by
+ // (loopkup, indices, dense_shape) is consistent.
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(dense_shape, 0), lookup_rank);
+
+ // Resize output tensor.
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
+ int k = 0;
+ int embedding_size = 1;
+ int lookup_size = 1;
+ for (int i = 0; i < lookup_rank - 1; i++, k++) {
+ const int dim = dense_shape->data.i32[i];
+ lookup_size *= dim;
+ output_shape->data[k] = dim;
+ }
+ for (int i = 1; i < embedding_rank; i++, k++) {
+ const int dim = SizeOfDimension(value, i);
+ embedding_size *= dim;
+ output_shape->data[k] = dim;
+ }
+ TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape));
+ const int output_size = lookup_size * embedding_size;
+ TfLiteTensorRealloc(output_size * sizeof(float), output);
+
+ tensor_utils::ZeroVector(output->data.f, output_size);
+
+ // Keep track of the current bucket for aggregation/combination.
+ int current_output_offset = 0;
+ float current_total_weight = 0.0;
+ float current_squares_weight = 0.0;
+ int num_elements = 0;
+
+ for (int i = 0; i < num_lookups; i++) {
+ int idx = ids->data.i32[i];
+ if (idx >= num_rows || idx < 0) {
+ context->ReportError(context,
+ "Embedding Lookup Sparse: index out of bounds.");
+ return kTfLiteError;
+ }
+
+ // Check where we need to aggregate.
+ const int example_indices_offset = i * lookup_rank;
+ int output_bucket = 0;
+ int stride = 1;
+ for (int k = (lookup_rank - 1) - 1; k >= 0; k--) {
+ output_bucket += indices->data.i32[example_indices_offset + k] * stride;
+ stride *= dense_shape->data.i32[k];
+ }
+ const int output_offset = output_bucket * embedding_size;
+
+ // If we are in a new aggregation bucket and the combiner is not the sum,
+ // go back and finalize the result of the previous bucket.
+ if (output_offset != current_output_offset) {
+ FinalizeAggregation(params->combiner, num_elements, current_total_weight,
+ current_squares_weight, embedding_size,
+ &output->data.f[current_output_offset]);
+
+ // Track next bucket.
+ num_elements = 0;
+ current_total_weight = 0.0;
+ current_squares_weight = 0.0;
+ current_output_offset = output_offset;
+ }
+
+ // Add element to aggregation.
+ ++num_elements;
+ const int example_embedding_offset = idx * embedding_size;
+ const float w = weights->data.f[i];
+ current_squares_weight += w * w;
+ current_total_weight += w;
+ for (int k = 0; k < embedding_size; k++) {
+ output->data.f[current_output_offset + k] +=
+ (value->data.f[example_embedding_offset + k] * w);
+ }
+ }
+
+ // Finalize last bucket.
+ FinalizeAggregation(params->combiner, num_elements, current_total_weight,
+ current_squares_weight, embedding_size,
+ &output->data.f[current_output_offset]);
+
+ return kTfLiteOk;
+}
+
+} // namespace
+
+TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE() {
+ static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc
new file mode 100644
index 0000000000..69d9c5cc7d
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc
@@ -0,0 +1,166 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite sparse lookup op.
+
+#include <cmath>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class EmbeddingLookupSparseOpModel : public SingleOpModel {
+ public:
+ EmbeddingLookupSparseOpModel(CombinerType type,
+ std::initializer_list<int> lookup_shape,
+ std::initializer_list<int> indices_shape,
+ std::initializer_list<int> dense_shape_shape,
+ std::initializer_list<int> value_shape) {
+ lookup_ = AddInput(TensorType_INT32);
+ indices_ = AddInput(TensorType_INT32);
+ dense_shape_ = AddInput(TensorType_INT32);
+ weights_ = AddInput(TensorType_FLOAT32);
+ value_ = AddInput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE,
+ BuiltinOptions_EmbeddingLookupSparseOptions,
+ CreateEmbeddingLookupSparseOptions(builder_, type).Union());
+ BuildInterpreter({lookup_shape, indices_shape, dense_shape_shape,
+ lookup_shape, value_shape});
+ }
+
+ void SetInput(std::initializer_list<int> lookup_data,
+ std::initializer_list<int> indices_data,
+ std::initializer_list<int> dense_shape_data,
+ std::initializer_list<float> weights_data) {
+ PopulateTensor(lookup_, lookup_data);
+ PopulateTensor(indices_, indices_data);
+ PopulateTensor(dense_shape_, dense_shape_data);
+ PopulateTensor(weights_, weights_data);
+ }
+
+ void Set3DWeightMatrix(const std::function<float(int, int, int)>& function) {
+ TfLiteTensor* tensor = interpreter_->tensor(value_);
+ int rows = tensor->dims->data[0];
+ int columns = tensor->dims->data[1];
+ int features = tensor->dims->data[2];
+ for (int i = 0; i < rows; i++) {
+ for (int j = 0; j < columns; j++) {
+ for (int k = 0; k < features; k++) {
+ tensor->data.f[(i * columns + j) * features + k] = function(i, j, k);
+ }
+ }
+ }
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ private:
+ int lookup_;
+ int weights_;
+ int indices_;
+ int dense_shape_;
+ int value_;
+ int output_;
+};
+
+TEST(EmbeddingLookupOpTest, SimpleTest) {
+ EmbeddingLookupSparseOpModel m(CombinerType_SUM, {3}, {3, 2}, {2}, {4, 3, 2});
+ m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0});
+ m.Set3DWeightMatrix(
+ [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({
+ 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // Row 1
+ 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, // -
+ 6.00, 6.06, 6.60, 6.66, 7.20, 7.26, // 2 * Row 3 + 4 * Row 0
+ })));
+}
+
+TEST(EmbeddingLookupOpTest, SimpleTestMean) {
+ EmbeddingLookupSparseOpModel m(CombinerType_MEAN, {3}, {3, 2}, {2},
+ {4, 3, 2});
+ m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0});
+ m.Set3DWeightMatrix(
+ [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({
+ 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // Row 1
+ 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, // -
+ 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // 2 * Row 3 + 4 * Row 0
+ })));
+}
+
+TEST(EmbeddingLookupOpTest, SimpleTestSqrtn) {
+ EmbeddingLookupSparseOpModel m(CombinerType_SQRTN, {3}, {3, 2}, {2},
+ {4, 3, 2});
+ m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0});
+ m.Set3DWeightMatrix(
+ [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
+ m.Invoke();
+
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({
+ 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // Row 1
+ 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, // -
+ 6.00f / std::sqrt(20.0f), 6.06f / std::sqrt(20.0f),
+ 6.60f / std::sqrt(20.0f), 6.66f / std::sqrt(20.0f),
+ 7.20f / std::sqrt(20.0f),
+ 7.26f /
+ std::sqrt(
+ 20.0f), // 2 * Row 3 + 4 * Row 0, // 2 * Row 3 + 4 * Row 0
+ })));
+}
+
+TEST(EmbeddingLookupOpTest, Indices3DTest) {
+ EmbeddingLookupSparseOpModel m(CombinerType_SUM, {3}, {3, 3}, {3}, {4, 3, 2});
+ m.SetInput({1, 3, 0}, {0, 0, 0, 2, 0, 0, 2, 0, 1}, {3, 2, 2},
+ {1.0, 2.0, 4.0});
+ m.Set3DWeightMatrix(
+ [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({
+ 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, 0.00, 0.00, 0.00,
+ 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
+ 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 6.00, 6.06, 6.60,
+ 6.66, 7.20, 7.26, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
+ })));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+#ifdef OS_LINUX
+ tflite::LogToStderr();
+#endif
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc
new file mode 100644
index 0000000000..8c030b0677
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc
@@ -0,0 +1,94 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite Lookup op.
+
+#include <iomanip>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class EmbeddingLookupOpModel : public SingleOpModel {
+ public:
+ EmbeddingLookupOpModel(std::initializer_list<int> index_shape,
+ std::initializer_list<int> weight_shape) {
+ input_ = AddInput(TensorType_INT32);
+ weight_ = AddInput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0);
+ BuildInterpreter({index_shape, weight_shape});
+ }
+
+ void SetInput(std::initializer_list<int> data) {
+ PopulateTensor(input_, data);
+ }
+
+ void Set3DWeightMatrix(const std::function<float(int, int, int)>& function) {
+ TfLiteTensor* tensor = interpreter_->tensor(weight_);
+ int rows = tensor->dims->data[0];
+ int columns = tensor->dims->data[1];
+ int features = tensor->dims->data[2];
+ for (int i = 0; i < rows; i++) {
+ for (int j = 0; j < columns; j++) {
+ for (int k = 0; k < features; k++) {
+ tensor->data.f[(i * columns + j) * features + k] = function(i, j, k);
+ }
+ }
+ }
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ private:
+ int input_;
+ int weight_;
+ int output_;
+};
+
+// TODO(ahentz): write more tests that exercise the details of the op, such as
+// lookup errors and variable input shapes.
+TEST(EmbeddingLookupOpTest, SimpleTest) {
+ EmbeddingLookupOpModel m({3}, {3, 2, 4});
+ m.PopulateTensor<int>(0, {1, 0, 2});
+ m.Set3DWeightMatrix(
+ [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({
+ 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ })));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc
new file mode 100644
index 0000000000..a77fe94e49
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/fully_connected.cc
@@ -0,0 +1,307 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/gemm_support.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace fully_connected {
+
+// This file has four implementations of FullyConnected
+enum KernelType {
+ kReference,
+ kGenericOptimized, // Neon-free
+ kNeonOptimized,
+ kPie, // Used by the PIE team
+};
+
+struct OpData {
+ // The scaling factor from input to output (aka the 'real multiplier') can
+ // be represented as a fixed point multipler plus a left shift.
+ int32_t output_multiplier;
+ int output_shift;
+ // The range of the fused activation layer. For example for kNone and
+ // uint8_t these would be 0 and 255.
+ int32_t output_activation_min;
+ int32_t output_activation_max;
+};
+
+constexpr int kInputTensor = 0;
+constexpr int kWeightsTensor = 1;
+constexpr int kBiasTensor = 2;
+constexpr int kOutputTensor = 0;
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ // This is a builtin op, so we don't use the contents in 'buffer', if any.
+ // Instead, we allocate a new object to carry information from Prepare() to
+ // Eval().
+ gemm_support::IncrementUsageCounter(context);
+ return new OpData;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ gemm_support::DecrementUsageCounter(context);
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ // Check we have all the inputs and outputs we need.
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 3);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
+ TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ // Check all the parameters of tensor match within themselves and match the
+ // input configuration.
+ int input_size = 1;
+ for (int i = 0; i < input->dims->size; i++) {
+ input_size *= input->dims->data[i];
+ }
+
+ const int batch_size = input_size / filter->dims->data[1];
+ const int num_units = filter->dims->data[0];
+
+ TF_LITE_ASSERT_EQ(input_size, batch_size * filter->dims->data[1]);
+ if (bias) {
+ TF_LITE_ASSERT_EQ(bias->dims->data[0], num_units);
+ }
+
+ TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 2);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(bias), 1);
+
+ // Note that quantized inference requires that all tensors have their
+ // parameters set. This is usually done during quantized training.
+ TfLiteType data_type = input->type;
+ if (data_type != kTfLiteFloat32) {
+ double real_multiplier = 0.0;
+ TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
+ context, input, filter, bias, output, &real_multiplier));
+ QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier,
+ &data->output_shift);
+ CalculateActivationRangeUint8(params->activation, output,
+ &data->output_activation_min,
+ &data->output_activation_max);
+ }
+
+ // Resize output.
+ TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
+ output_size_array->data[0] = batch_size;
+ output_size_array->data[1] = num_units;
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, output, output_size_array));
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalPie(TfLiteContext* context, TfLiteNode* node,
+ TfLiteFullyConnectedParams* params, OpData* data,
+ TfLiteTensor* input, TfLiteTensor* filter,
+ TfLiteTensor* bias, TfLiteTensor* output) {
+ int total_input_size = 1;
+ for (int i = 0; i < input->dims->size; i++) {
+ total_input_size *= input->dims->data[i];
+ }
+
+ int input_size = filter->dims->data[1];
+ const int batch_size = total_input_size / filter->dims->data[1];
+ const int num_units = filter->dims->data[0];
+
+ // Output = bias if bias tensor exists.
+ if (bias) {
+ tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size,
+ output->data.f);
+ } else {
+ tensor_utils::ZeroVector(output->data.f, batch_size * num_units);
+ }
+
+ // Compute output += weight * input
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ filter->data.f, num_units, input_size, input->data.f, batch_size,
+ output->data.f, /*result_stride=*/1);
+
+ // Apply activation function
+ tensor_utils::ApplyActivationToVector(output->data.f, batch_size * num_units,
+ params->activation, output->data.f);
+
+ return kTfLiteOk;
+}
+
+#define TF_LITE_MACRO_DISPATCH(macro_name, params, target_namespace) \
+ if (params->activation == kTfLiteActNone) { \
+ macro_name(target_namespace, kNone); \
+ } \
+ if (params->activation == kTfLiteActRelu) { \
+ macro_name(target_namespace, kRelu); \
+ } \
+ if (params->activation == kTfLiteActRelu6) { \
+ macro_name(target_namespace, kRelu6); \
+ }
+
+template <KernelType kernel_type>
+TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteFullyConnectedParams* params, OpData* data,
+ TfLiteTensor* input, TfLiteTensor* filter,
+ TfLiteTensor* bias, TfLiteTensor* output) {
+ gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);
+
+ int32_t input_offset = -input->params.zero_point;
+ int32_t filter_offset = -filter->params.zero_point;
+ int32_t output_offset = output->params.zero_point;
+#define TF_LITE_FULLY_CONNECTED(type) \
+ type::FullyConnected( \
+ GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset, \
+ GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset, \
+ GetTensorData<int32_t>(bias), GetTensorDims(bias), output_offset, \
+ data->output_multiplier, data->output_shift, \
+ data->output_activation_min, data->output_activation_max, \
+ GetTensorData<uint8_t>(output), GetTensorDims(output), gemm_context)
+ if (kernel_type == kReference) {
+ TF_LITE_FULLY_CONNECTED(reference_ops);
+ } else if (kernel_type == kPie) {
+ // TODO(ahentz): we don't have a quantized version of the PIE kernels, so
+ // we just defer to the MINI ones.
+ TF_LITE_FULLY_CONNECTED(optimized_ops);
+ } else {
+ TF_LITE_FULLY_CONNECTED(optimized_ops);
+ }
+#undef TF_LITE_FULLY_CONNECTED
+
+ return kTfLiteOk;
+}
+
+template <KernelType kernel_type>
+TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
+ TfLiteFullyConnectedParams* params, OpData* data,
+ TfLiteTensor* input, TfLiteTensor* filter,
+ TfLiteTensor* bias, TfLiteTensor* output) {
+ float output_activation_min, output_activation_max;
+ CalculateActivationRangeFloat(params->activation, &output_activation_min,
+ &output_activation_max);
+#define TF_LITE_FULLY_CONNECTED(type) \
+ type::FullyConnected(GetTensorData<float>(input), GetTensorDims(input), \
+ GetTensorData<float>(filter), GetTensorDims(filter), \
+ GetTensorData<float>(bias), GetTensorDims(bias), \
+ output_activation_min, output_activation_max, \
+ GetTensorData<float>(output), GetTensorDims(output))
+ if (kernel_type == kReference) {
+ TF_LITE_FULLY_CONNECTED(reference_ops);
+ } else if (kernel_type == kPie) {
+ return EvalPie(context, node, params, data, input, filter, bias, output);
+ } else {
+ TF_LITE_FULLY_CONNECTED(optimized_ops);
+ }
+#undef TF_LITE_FULLY_CONNECTED
+
+ return kTfLiteOk;
+}
+
+#undef TF_LITE_MACRO_DISPATCH
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
+ TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ switch (input->type) { // Already know in/out types are same.
+ case kTfLiteFloat32:
+ return EvalFloat<kernel_type>(context, node, params, data, input, filter,
+ bias, output);
+ case kTfLiteUInt8:
+ return EvalQuantized<kernel_type>(context, node, params, data, input,
+ filter, bias, output);
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace fully_connected
+
+TfLiteRegistration* Register_FULLY_CONNECTED_REF() {
+ static TfLiteRegistration r = {
+ fully_connected::Init, fully_connected::Free, fully_connected::Prepare,
+ fully_connected::Eval<fully_connected::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_FULLY_CONNECTED_NEON_OPT() {
+ static TfLiteRegistration r = {
+ fully_connected::Init, fully_connected::Free, fully_connected::Prepare,
+ fully_connected::Eval<fully_connected::kNeonOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_FULLY_CONNECTED_GENERIC_OPT() {
+ static TfLiteRegistration r = {
+ fully_connected::Init, fully_connected::Free, fully_connected::Prepare,
+ fully_connected::Eval<fully_connected::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_FULLY_CONNECTED_PIE() {
+ static TfLiteRegistration r = {fully_connected::Init, fully_connected::Free,
+ fully_connected::Prepare,
+ fully_connected::Eval<fully_connected::kPie>};
+ return &r;
+}
+
+TfLiteRegistration* Register_FULLY_CONNECTED() {
+ // TODO(ahentz): We don't have a dedicated quantized version of the PIE
+ // kernel. For now, the quantized version just defer to the corresponding
+ // optimized MINI kernel. At some point we will allow different libraries to
+ // be built with different kernels, but for now we have to pick one here.
+ return Register_FULLY_CONNECTED_PIE();
+#ifdef USE_NEON
+ return Register_FULLY_CONNECTED_NEON_OPT();
+#else
+ return Register_FULLY_CONNECTED_GENERIC_OPT();
+#endif
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/kernels/fully_connected_test.cc
new file mode 100644
index 0000000000..112e3f1ba0
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/fully_connected_test.cc
@@ -0,0 +1,377 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite FULLY_CONNECTED op.
+
+#include <iomanip>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+
+static float fully_connected_input[] = {
+ 0.503691, 0.196961, 0.521017, 0.554248, 0.288678, 0.792476, 0.561653,
+ 0.462230, 0.650736, 0.163132, 0.029658, 0.411544, 0.470539, 0.572390,
+ 0.538755, 0.212030, 0.264309, 0.193908, 0.777480, 0.745661, 0.423314,
+ 0.470804, 0.175501, 0.492225, 0.192743, 0.540183, 0.372514, 0.446550,
+ 0.498173, 0.126472, 0.132706, 0.001864, 0.323433, 0.653723, 0.556112,
+ 0.612111, 0.446199, 0.117765, 0.074341, 0.096935, 0.280897, 0.103999,
+ 0.508479, 0.751437, 0.676389, 0.047234, 0.963467, 0.940698, 0.241142,
+ 0.740947, 0.686359, 0.664456, 0.211751, 0.861860, 0.156681, 0.404494,
+ 0.402043, 0.529195, 0.851044, 0.900216, 0.655667, 0.983750, 0.902081,
+ 0.979100, 0.637473, 0.458193, 0.591211, 0.083671, 0.575958, 0.665552,
+ 0.180606, 0.856856, 0.769551, 0.689086, 0.608293, 0.445940, 0.736320,
+ 0.571760, 0.386637, 0.977461, 0.312707, 0.072996, 0.641918, 0.524458,
+ 0.934856, 0.798598, 0.928951, 0.336899, 0.327793, 0.779995, 0.237115,
+ 0.983460, 0.763746, 0.139196, 0.962560, 0.401218, 0.597389, 0.553771,
+ 0.484890, 0.173347, 0.219322, 0.665496, 0.030203, 0.988873, 0.354582,
+ 0.638496, 0.434813, 0.090902, 0.210256, 0.821450, 0.068363, 0.522962,
+ 0.894446, 0.710280, 0.047420, 0.829302, 0.508879, 0.976371, 0.166202,
+ 0.836672, 0.756367, 0.403317, 0.820132, 0.520112, 0.542513, 0.782691,
+ 0.921330, 0.139902};
+
+static float fully_connected_golden_output[] = {
+ 0, 0.0732134, 0, 0, 0, 0.280859,
+ 0, 0.128927, 0, 0.0777251, 0, 0.270268,
+ 0.271435, 0.0173503, 0.335465, 0.235562,
+
+ 0, 0.0745866, 0, 0.051611, 0, 0.253876,
+ 0, 0.0814873, 0, 0.104104, 0, 0.248529,
+ 0.264194, 0, 0.302973, 0.166252,
+
+ 0, 0.0170409, 0, 0.0509851, 0, 0.212834,
+ 0, 0.0208326, 0, 0.129932, 0.203978, 0.103428,
+ 0.298051, 0, 0.332233, 0.00445903,
+
+ 0, 0.125246, 0, 0.0735336, 0, 0.0910256,
+ 0, 0, 0, 0.18933, 0.378111, 0.0712443,
+ 0.277298, 0.0123414, 0.267454, 0,
+
+ 0, 0.14687, 0, 0.155495, 0.0300215, 0.147256,
+ 0, 0, 0, 0.156412, 0.434914, 0.0461529,
+ 0.246508, 0, 0.363138, 0,
+
+ 0, 0, 0, 0.0212949, 0, 0.301708,
+ 0, 0.35497, 0, 0.406223, 0.0260211, 0.049195,
+ 0.197161, 0, 0.37316, 0,
+
+ 0, 0.221783, 0, 0, 0.0116515, 0.281945,
+ 0, 0, 0, 0, 0.285626, 0.181773,
+ 0.296401, 0.170452, 0.367135, 0.142597,
+
+ 0, 0, 0, 0, 0, 0.418886,
+ 0, 0.291063, 0, 0.227541, 0.0424759, 0.27589,
+ 0.398286, 0.177146, 0.40359, 0.121452,
+
+ 0, 0.0834884, 0, 0, 0, 0.287441,
+ 0, 0.0046838, 0, 0.0122087, 0, 0.217376,
+ 0.140183, 0.0948412, 0.436677, 0.0589876,
+
+ 0, 0.0289969, 0, 0.0921397, 0, 0.396802,
+ 0, 0.0126157, 0, 0.0968433, 0, 0.172271,
+ 0.173295, 0.0664741, 0.53645, 0.00915603,
+
+ 0, 0, 0, 0, 0, 0.147942,
+ 0, 0.263795, 0, 0.39782, 0, 0.382435,
+ 0.561072, 0.0579847, 0.145712, 0.13508,
+
+ 0, 0, 0, 0.16382, 0, 0.322294,
+ 0, 0.163798, 0, 0.405211, 0.367953, 0.076852,
+ 0.342473, 0.0834118, 0.377537, 0,
+
+ 0, 0.206, 0, 0, 0, 0.375769,
+ 0, 0, 0, 0, 0, 0.125165,
+ 0, 0.105591, 0.52055, 0.0536445,
+
+ 0, 0.259261, 0, 0, 0, 0.247707,
+ 0, 0, 0, 0, 0, 0.215862,
+ 0.149153, 0.224678, 0.359519, 0.129419,
+
+ 0, 0.17611, 0, 0.280895, 0, 0.576484,
+ 0, 0.000418848, 0, 0, 0, 0.151112,
+ 0.211902, 0, 0.566341, 0.106305,
+
+ 0, 0.0246284, 0, 0, 0, 0.196267,
+ 0, 0.0248624, 0, 0.265635, 0, 0.436199,
+ 0.408079, 0.134514, 0.328489, 0.411368};
+
+class BaseFullyConnectedOpModel : public SingleOpModel {
+ public:
+ // TODO(ahentz): test different activation types too.
+ BaseFullyConnectedOpModel(int units, int batches, const TensorData& input,
+ const TensorData& output = {TensorType_FLOAT32})
+ : batches_(batches), units_(units) {
+ int total_input_size = 1;
+ for (int i = 0; i < input.shape.size(); ++i) {
+ total_input_size *= input.shape[i];
+ }
+ input_size_ = total_input_size / batches_;
+
+ input_ = AddInput(input);
+ weights_ =
+ AddInput({input.type, {units_, input_size_}, input.min, input.max});
+
+ if (input.type == TensorType_FLOAT32) {
+ bias_ = AddInput({TensorType_FLOAT32, {units_}});
+ } else {
+ // This is a quantized version. The scale of 'bias' depends on the scales
+ // of input and filter. Supposedly this is correctly set during quantized
+ // training.
+ auto bias_scale = GetScale(input_) * GetScale(weights_);
+ TensorData bias{TensorType_INT32, {units_}, 0, 0, bias_scale};
+ bias_ = AddInput(bias);
+ }
+
+ output_ = AddOutput(output);
+
+ SetBuiltinOp(
+ BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
+ CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU)
+ .Union());
+ BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)});
+ }
+
+ int input_size() { return input_size_; }
+ int num_units() { return units_; }
+ int num_batches() { return batches_; }
+
+ protected:
+ int input_;
+ int weights_;
+ int bias_;
+ int output_;
+
+ int batches_;
+ int units_;
+ int input_size_;
+};
+
+class FloatFullyConnectedOpModel : public BaseFullyConnectedOpModel {
+ public:
+ using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel;
+
+ void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
+
+ void SetWeights(std::initializer_list<float> f) {
+ PopulateTensor(weights_, f);
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+ void SetInput(int offset, float* begin, float* end) {
+ PopulateTensor(input_, offset, begin, end);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel {
+ public:
+ using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel;
+
+ void SetBias(std::initializer_list<float> data) {
+ QuantizeAndPopulate<int32_t>(bias_, data);
+ }
+ void SetWeights(std::initializer_list<float> data) {
+ QuantizeAndPopulate<uint8_t>(weights_, data);
+ }
+ void SetInput(std::initializer_list<float> data) {
+ QuantizeAndPopulate<uint8_t>(input_, data);
+ }
+
+ std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
+};
+
+// TODO(ahentz): add more small tests like this one, focused on making sure the
+// calculations are correct.
+TEST(FullyConnectedOpTest, SimpleTest) {
+ FloatFullyConnectedOpModel m(3, 2, {TensorType_FLOAT32, {2, 10}});
+ m.SetWeights({
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ });
+ m.SetBias({1, 2, 3});
+
+ m.SetInput({
+ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
+ 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
+}
+
+TEST(FullyConnectedOpTest, SimpleTestQuantized) {
+ QuantizedFullyConnectedOpModel m(
+ 3, 2,
+ /*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64},
+ /*output=*/{TensorType_UINT8, {}, -127, 128});
+
+ // input_product_scale < output_scale was not true.
+ m.SetWeights({
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ });
+ m.SetBias({1, 2, 3});
+
+ m.SetInput({
+ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
+ 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({
+ 24, 25, 26, //
+ 58, 59, 60, //
+ })));
+ EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187));
+}
+
+TEST(FullyConnectedOpTest, SimpleTest4DInput) {
+ // Note that it is not required that the first dimension be the number of
+ // batches. All we care is that the input can be evenly distributed in
+ // batches. In this case, we need the input to have multiples of '2'.
+ FloatFullyConnectedOpModel m(/*units=*/3,
+ /*batches=*/2,
+ /*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}});
+ m.SetWeights({
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ });
+ m.SetBias({1, 2, 3});
+
+ m.SetInput({
+ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // first batch
+ 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // second batch
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 24, 25, 26, // first batch
+ 58, 59, 60, // second batch
+ }));
+}
+
+TEST(FullyConnectedOpTest, SimpleTest4dInputQuantized) {
+ QuantizedFullyConnectedOpModel m(
+ 3, 2,
+ /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -63.5, 64},
+ /*output=*/{TensorType_UINT8, {}, -127, 128});
+
+ // input_product_scale < output_scale was not true.
+ m.SetWeights({
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ });
+ m.SetBias({1, 2, 3});
+
+ m.SetInput({
+ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
+ 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({
+ 24, 25, 26, //
+ 58, 59, 60, //
+ })));
+ EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187));
+}
+
+// TODO(ahentz): Reconsider this test. Having arbitrary weights makes it hard
+// to debug errors and doesn't necessarily test all the important details.
+TEST(FullyConnectedOpTest, BlackBoxTest) {
+ FloatFullyConnectedOpModel m(16, 2, {TensorType_FLOAT32, {2, 8}});
+ m.SetWeights(
+ {0.091327, 0.103366, -0.316505, -0.083120, 0.149366, -0.196636,
+ -0.123672, 0.062800, 0.063031, 0.191670, -0.062001, -0.061504,
+ -0.275581, 0.059388, -0.118497, -0.079224, 0.109758, 0.008307,
+ -0.062657, -0.060962, -0.049782, -0.106719, -0.319482, -0.103650,
+ 0.266455, 0.051517, -0.123448, 0.322464, 0.043282, -0.173782,
+ -0.190381, 0.002013, 0.096086, 0.131157, 0.031164, 0.100638,
+ -0.312191, -0.080923, -0.101318, -0.116614, 0.142238, 0.086540,
+ -0.139154, 0.174268, -0.073161, 0.080072, 0.006874, 0.229382,
+ -0.104321, -0.176035, -0.208587, -0.001019, -0.162032, 0.080824,
+ -0.025021, 0.074460, -0.252595, -0.161750, -0.136403, 0.008308,
+ 0.005710, 0.096600, 0.289839, 0.218816, -0.304651, -0.070958,
+ 0.054598, 0.147113, -0.139112, -0.072798, -0.163335, -0.167863,
+ -0.128762, -0.035780, 0.117262, 0.017177, 0.263335, -0.176612,
+ 0.262961, -0.093654, -0.339283, 0.333071, 0.180827, 0.287583,
+ 0.066350, -0.197947, -0.114449, -0.236035, 0.103532, -0.034284,
+ 0.093299, -0.145361, 0.054001, 0.250570, 0.157010, -0.143480,
+ -0.139061, -0.048873, 0.067557, 0.139038, 0.324106, 0.227041,
+ 0.037793, -0.225747, -0.241619, 0.357835, 0.135762, -0.306764,
+ -0.125982, 0.091916, 0.266587, 0.030135, 0.265148, 0.141627,
+ 0.020120, 0.083815, -0.124556, -0.100124, -0.048159, 0.181172,
+ 0.302309, -0.041084, 0.146334, -0.061511, -0.232605, 0.281324,
+ 0.145408, -0.221897});
+ m.SetBias({-0.160594, 0.205770, -0.078307, -0.077984, 0.001937, 0.015860,
+ 0.036810, 0.012346, 0.001028, 0.038551, 0.075415, 0.020804,
+ 0.048478, -0.032270, 0.175688, -0.085662});
+
+ const int input_sequence_size = sizeof(fully_connected_input) /
+ sizeof(float) /
+ (m.input_size() * m.num_batches());
+ for (int i = 0; i < input_sequence_size; i++) {
+ // TODO(ahentz): This is what the original test was doing: two equal
+ // batches per invocation. We could instead use two different batches.
+ float* batch_start = fully_connected_input + i * m.input_size();
+ float* batch_end = batch_start + m.input_size();
+ m.SetInput(0, batch_start, batch_end);
+ m.SetInput(m.input_size(), batch_start, batch_end);
+
+ m.Invoke();
+
+ float* golden_start = fully_connected_golden_output + i * m.num_units();
+ float* golden_end = golden_start + m.num_units();
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+ expected.insert(expected.end(), golden_start, golden_end);
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ }
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/gemm_support.cc b/tensorflow/contrib/lite/kernels/gemm_support.cc
new file mode 100644
index 0000000000..eb2b0aacf7
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/gemm_support.cc
@@ -0,0 +1,68 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/gemm_support.h"
+
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace gemm_support {
+
+struct RefCountedGemmContext {
+ gemmlowp::GemmContext* gemm_context_ = nullptr;
+ int num_references_ = 0;
+};
+
+void IncrementUsageCounter(TfLiteContext* context) {
+ auto* ptr = reinterpret_cast<RefCountedGemmContext*>(context->gemm_context);
+ if (ptr == nullptr) {
+ ptr = new RefCountedGemmContext;
+ ptr->gemm_context_ = new gemmlowp::GemmContext();
+ ptr->num_references_ = 0;
+ context->gemm_context = ptr;
+ }
+ ptr->num_references_++;
+}
+
+void DecrementUsageCounter(TfLiteContext* context) {
+ auto* ptr = reinterpret_cast<RefCountedGemmContext*>(context->gemm_context);
+ if (ptr == nullptr) {
+ TF_LITE_FATAL(
+ "Call to DecrementUsageCounter() not preceded by "
+ "IncrementUsageCounter()");
+ }
+ if (--ptr->num_references_ == 0) {
+ delete ptr->gemm_context_;
+ delete ptr;
+ context->gemm_context = nullptr;
+ }
+}
+
+gemmlowp::GemmContext* GetFromContext(TfLiteContext* context) {
+ auto* ptr = reinterpret_cast<RefCountedGemmContext*>(context->gemm_context);
+ if (ptr == nullptr) {
+ TF_LITE_FATAL(
+ "Call to GetFromContext() not preceded by IncrementUsageCounter()");
+ }
+ return ptr->gemm_context_;
+}
+
+void SetMaxNumThreads(TfLiteContext* context, int num_threads) {
+ IncrementUsageCounter(context);
+ GetFromContext(context)->set_max_num_threads(num_threads);
+ DecrementUsageCounter(context);
+}
+
+} // namespace gemm_support
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h
new file mode 100644
index 0000000000..b531959ffb
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/gemm_support.h
@@ -0,0 +1,54 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_
+
+#include "public/gemmlowp.h"
+#include "tensorflow/contrib/lite/context.h"
+
+namespace tflite {
+namespace gemm_support {
+
+// Returns the GemmContext stored in 'context', allowing multiple ops to
+// share a single object, as long as they share a TfLiteContext. The caller
+// must ensure that this is called between IncrementUsageCounter() and
+// DecrementUsageCounter(). For example, in the implementation of an op:
+// void* Init(TfLiteContext* context, const char*, size_t) {
+// gemm_support::IncrementUsageCounter(context);
+// return nullptr;
+// }
+// void Free(TfLiteContext* context, void*) {
+// gemm_support::DecrementUsageCounter(context);
+// }
+// TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+// auto* gemm_context = gemm_support::GetFromContext(context);
+// }
+gemmlowp::GemmContext* GetFromContext(TfLiteContext* context);
+
+// Let the framework know that the GemmContext stored in 'context' will be used
+// by an op. If necessary a new GemmContext is created and placed in 'context'.
+void IncrementUsageCounter(TfLiteContext* context);
+
+// Let the framework know that the op stopped using the GemmContext stored in
+// 'context'. If there are no more usages the GemmContext will be deleted.
+void DecrementUsageCounter(TfLiteContext* context);
+
+// Set the maximum number threads available for gemmlowp operations.
+void SetMaxNumThreads(TfLiteContext* context, int num_threads);
+
+} // namespace gemm_support
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_
diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
new file mode 100644
index 0000000000..3b82601d11
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
@@ -0,0 +1,155 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Op that looks up items from hashtable.
+//
+// Input:
+// Tensor[0]: Hash key to lookup, dim.size == 1, int32
+// Tensor[1]: Key of hashtable, dim.size == 1, int32
+// *MUST* be sorted in ascending order.
+// Tensor[2]: Value of hashtable, dim.size >= 1
+// Tensor[1].Dim[0] == Tensor[2].Dim[0]
+//
+// Output:
+// Output[0].dim[0] == Tensor[0].dim[0], num of lookups
+// Each item in output is a raw bytes copy of corresponding item in input.
+// When key does not exist in hashtable, the returned bytes are all 0s.
+//
+// Output[1].dim = { Tensor[0].dim[0] }, num of lookups
+// Each item indicates whether the corresponding lookup has a returned value.
+// 0 for missing key, 1 for found key.
+
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+
+namespace {
+
+int greater(const void* a, const void* b) {
+ return *static_cast<const int*>(a) - *static_cast<const int*>(b);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
+
+ TfLiteTensor* lookup = GetInput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
+ TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
+
+ TfLiteTensor* key = GetInput(context, node, 1);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(key), 1);
+ TF_LITE_ENSURE_EQ(context, key->type, kTfLiteInt32);
+
+ TfLiteTensor* value = GetInput(context, node, 2);
+ TF_LITE_ENSURE(context, NumDimensions(value) >= 1);
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(key, 0),
+ SizeOfDimension(value, 0));
+ if (value->type == kTfLiteString) {
+ TF_LITE_ENSURE_EQ(context, NumDimensions(value), 1);
+ }
+
+ TfLiteTensor* hits = GetOutput(context, node, 1);
+ TF_LITE_ENSURE_EQ(context, hits->type, kTfLiteUInt8);
+ TfLiteIntArray* hitSize = TfLiteIntArrayCreate(1);
+ hitSize->data[0] = SizeOfDimension(lookup, 0);
+
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, value->type, output->type);
+
+ TfLiteStatus status = kTfLiteOk;
+ if (output->type != kTfLiteString) {
+ TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value));
+ outputSize->data[0] = SizeOfDimension(lookup, 0);
+ for (int i = 1; i < NumDimensions(value); i++) {
+ outputSize->data[i] = SizeOfDimension(value, i);
+ }
+ status = context->ResizeTensor(context, output, outputSize);
+ }
+ if (context->ResizeTensor(context, hits, hitSize) == kTfLiteError) {
+ status = kTfLiteError;
+ }
+ return status;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TfLiteTensor* hits = GetOutput(context, node, 1);
+ TfLiteTensor* lookup = GetInput(context, node, 0);
+ TfLiteTensor* key = GetInput(context, node, 1);
+ TfLiteTensor* value = GetInput(context, node, 2);
+
+ const int num_rows = SizeOfDimension(value, 0);
+ const int row_bytes = value->bytes / num_rows;
+ void* pointer = nullptr;
+ DynamicBuffer buf;
+
+ for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
+ int idx = -1;
+ pointer = bsearch(&(lookup->data.i32[i]), key->data.i32, num_rows,
+ sizeof(int32_t), greater);
+ if (pointer != nullptr) {
+ idx = (reinterpret_cast<char*>(pointer) - (key->data.raw)) /
+ sizeof(int32_t);
+ }
+
+ if (idx >= num_rows || idx < 0) {
+ if (output->type == kTfLiteString) {
+ buf.AddString(nullptr, 0);
+ } else {
+ memset(output->data.raw + i * row_bytes, 0, row_bytes);
+ }
+ hits->data.uint8[i] = 0;
+ } else {
+ if (output->type == kTfLiteString) {
+ buf.AddString(GetString(value, idx));
+ } else {
+ memcpy(output->data.raw + i * row_bytes,
+ value->data.raw + idx * row_bytes, row_bytes);
+ }
+ hits->data.uint8[i] = 1;
+ }
+ }
+ if (output->type == kTfLiteString) {
+ buf.WriteToTensor(output);
+ }
+
+ return kTfLiteOk;
+}
+} // namespace
+
+TfLiteRegistration* Register_HASHTABLE_LOOKUP() {
+ static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc
new file mode 100644
index 0000000000..916a23225e
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc
@@ -0,0 +1,176 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite Lookup op.
+
+#include <iomanip>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class HashtableLookupOpModel : public SingleOpModel {
+ public:
+ HashtableLookupOpModel(std::initializer_list<int> lookup_shape,
+ std::initializer_list<int> key_shape,
+ std::initializer_list<int> value_shape,
+ TensorType type) {
+ lookup_ = AddInput(TensorType_INT32);
+ key_ = AddInput(TensorType_INT32);
+ value_ = AddInput(type);
+ output_ = AddOutput(type);
+ hit_ = AddOutput(TensorType_UINT8);
+ SetBuiltinOp(BuiltinOperator_HASHTABLE_LOOKUP, BuiltinOptions_NONE, 0);
+ BuildInterpreter({lookup_shape, key_shape, value_shape});
+ }
+
+ void SetLookup(std::initializer_list<int> data) {
+ PopulateTensor<int>(lookup_, data);
+ }
+
+ void SetHashtableKey(std::initializer_list<int> data) {
+ PopulateTensor<int>(key_, data);
+ }
+
+ void SetHashtableValue(const std::vector<string>& content) {
+ PopulateStringTensor(value_, content);
+ }
+
+ void SetHashtableValue(const std::function<float(int)>& function) {
+ TfLiteTensor* tensor = interpreter_->tensor(value_);
+ int rows = tensor->dims->data[0];
+ for (int i = 0; i < rows; i++) {
+ tensor->data.f[i] = function(i);
+ }
+ }
+
+ void SetHashtableValue(const std::function<float(int, int)>& function) {
+ TfLiteTensor* tensor = interpreter_->tensor(value_);
+ int rows = tensor->dims->data[0];
+ int features = tensor->dims->data[1];
+ for (int i = 0; i < rows; i++) {
+ for (int j = 0; j < features; j++) {
+ tensor->data.f[i * features + j] = function(i, j);
+ }
+ }
+ }
+
+ std::vector<string> GetStringOutput() {
+ TfLiteTensor* output = interpreter_->tensor(output_);
+ int num = GetStringCount(output);
+ std::vector<string> result(num);
+ for (int i = 0; i < num; i++) {
+ auto ref = GetString(output, i);
+ result[i] = string(ref.str, ref.len);
+ }
+ return result;
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ std::vector<uint8_t> GetHit() { return ExtractVector<uint8_t>(hit_); }
+
+ private:
+ int lookup_;
+ int key_;
+ int value_;
+ int output_;
+ int hit_;
+};
+
+// TODO(yichengfan): write more tests that exercise the details of the op,
+// such as lookup errors and variable input shapes.
+TEST(HashtableLookupOpTest, Test2DInput) {
+ HashtableLookupOpModel m({4}, {3}, {3, 2}, TensorType_FLOAT32);
+
+ m.SetLookup({1234, -292, -11, 0});
+ m.SetHashtableKey({-11, 0, 1234});
+ m.SetHashtableValue([](int i, int j) { return i + j / 10.0f; });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 2.0, 2.1, // 2-nd item
+ 0, 0, // Not found
+ 0.0, 0.1, // 0-th item
+ 1.0, 1.1, // 1-st item
+ })));
+ EXPECT_THAT(m.GetHit(), ElementsAreArray({
+ 1, 0, 1, 1,
+ }));
+}
+
+TEST(HashtableLookupOpTest, Test1DInput) {
+ HashtableLookupOpModel m({4}, {3}, {3}, TensorType_FLOAT32);
+
+ m.SetLookup({1234, -292, -11, 0});
+ m.SetHashtableKey({-11, 0, 1234});
+ m.SetHashtableValue([](int i) { return i * i / 10.0f; });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 0.4, // 2-nd item
+ 0, // Not found
+ 0.0, // 0-th item
+ 0.1, // 1-st item
+ })));
+ EXPECT_THAT(m.GetHit(), ElementsAreArray({
+ 1,
+ 0,
+ 1,
+ 1,
+ }));
+}
+
+TEST(HashtableLookupOpTest, TestString) {
+ HashtableLookupOpModel m({4}, {3}, {3}, TensorType_STRING);
+
+ m.SetLookup({1234, -292, -11, 0});
+ m.SetHashtableKey({-11, 0, 1234});
+ m.SetHashtableValue({"Hello", "", "Hi"});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({
+ "Hi", // 2-nd item
+ "", // Not found
+ "Hello", // 0-th item
+ "", // 1-st item
+ }));
+ EXPECT_THAT(m.GetHit(), ElementsAreArray({
+ 1,
+ 0,
+ 1,
+ 1,
+ }));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
new file mode 100644
index 0000000000..288534099b
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -0,0 +1,359 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
+
+tflite_deps_intel = [
+ "@arm_neon_2_x86_sse",
+]
+
+NEON_FLAGS_IF_APPLICABLE = select({
+ ":arm": [
+ "-O3",
+ "-mfpu=neon",
+ "-mfloat-abi=softfp",
+ ],
+ ":armeabi-v7a": [
+ "-O3",
+ "-mfpu=neon",
+ "-mfloat-abi=softfp",
+ ],
+ ":armv7a": [
+ "-O3",
+ "-mfpu=neon",
+ "-mfloat-abi=softfp",
+ ],
+ "//conditions:default": [
+ "-O3",
+ ],
+})
+
+cc_library(
+ name = "types",
+ srcs = [],
+ hdrs = [
+ "compatibility.h",
+ "types.h",
+ ],
+)
+
+config_setting(
+ name = "arm",
+ values = {
+ "cpu": "arm",
+ },
+)
+
+config_setting(
+ name = "arm64-v8a",
+ values = {
+ "cpu": "arm64-v8a",
+ },
+)
+
+config_setting(
+ name = "armv7a",
+ values = {
+ "cpu": "armv7a",
+ },
+)
+
+config_setting(
+ name = "armeabi-v7a",
+ values = {
+ "cpu": "armeabi-v7a",
+ },
+)
+
+config_setting(
+ name = "haswell",
+ values = {
+ "cpu": "haswell",
+ },
+)
+
+config_setting(
+ name = "ios_x86_64",
+ values = {
+ "cpu": "ios_x86_64",
+ },
+)
+
+config_setting(
+ name = "ios_armv7",
+ values = {
+ "cpu": "ios_armv7",
+ },
+)
+
+config_setting(
+ name = "ios_arm64",
+ values = {
+ "cpu": "ios_arm64",
+ },
+)
+
+config_setting(
+ name = "k8",
+ values = {
+ "cpu": "k8",
+ },
+)
+
+config_setting(
+ name = "x86",
+ values = {
+ "cpu": "x86",
+ },
+)
+
+config_setting(
+ name = "x86_64",
+ values = {
+ "cpu": "x86_64",
+ },
+)
+
+config_setting(
+ name = "darwin",
+ values = {
+ "cpu": "darwin",
+ },
+)
+
+cc_library(
+ name = "optimized_base",
+ srcs = [],
+ hdrs = [
+ "common.h",
+ "optimized/depthwiseconv_float.h",
+ "optimized/depthwiseconv_uint8.h",
+ "optimized/optimized_ops.h",
+ ],
+ copts = tflite_copts(),
+ deps = [
+ ":types",
+ ":round",
+ "//third_party/eigen3",
+ "@gemmlowp//:gemmlowp",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ ] + select({
+ ":haswell": tflite_deps_intel,
+ ":ios_x86_64": tflite_deps_intel,
+ ":k8": tflite_deps_intel,
+ ":x86": tflite_deps_intel,
+ ":x86_64": tflite_deps_intel,
+ ":darwin": tflite_deps_intel,
+ "//conditions:default": [],
+ }),
+)
+
+cc_library(
+ name = "optimized",
+ hdrs = [
+ "optimized/eigen_spatial_convolutions.h",
+ "optimized/eigen_tensor_reduced_instantiations_oss.h",
+ "optimized/multithreaded_conv.h",
+ "tensor.h",
+ ],
+ deps = [
+ ":optimized_base",
+ ":types",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:context",
+ "//third_party/eigen3",
+ ],
+)
+
+cc_test(
+ name = "tensor_test",
+ srcs = ["tensor_test.cc"],
+ deps = [
+ ":reference",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "round",
+ srcs = [],
+ hdrs = ["round.h"],
+)
+
+cc_library(
+ name = "quantization_util",
+ srcs = ["quantization_util.cc"],
+ hdrs = [
+ "compatibility.h",
+ "quantization_util.h",
+ ],
+ deps = [":round"],
+)
+
+cc_test(
+ name = "quantization_util_test",
+ srcs = ["quantization_util_test.cc"],
+ deps = [
+ ":quantization_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "reference_base",
+ srcs = [],
+ hdrs = [
+ "common.h",
+ "reference/depthwiseconv_float.h",
+ "reference/depthwiseconv_uint8.h",
+ "reference/reference_ops.h",
+ ],
+ deps = [
+ ":round",
+ ":types",
+ "//third_party/eigen3",
+ "@gemmlowp//:gemmlowp",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ ] + select({
+ ":haswell": tflite_deps_intel,
+ ":ios_x86_64": tflite_deps_intel,
+ ":k8": tflite_deps_intel,
+ ":x86": tflite_deps_intel,
+ ":x86_64": tflite_deps_intel,
+ ":darwin": tflite_deps_intel,
+ "//conditions:default": [],
+ }),
+)
+
+cc_library(
+ name = "reference",
+ hdrs = ["tensor.h"],
+ deps = [
+ ":types",
+ "//tensorflow/contrib/lite:context",
+ ],
+)
+
+cc_library(
+ name = "portable_tensor_utils",
+ srcs = [
+ "reference/portable_tensor_utils.cc",
+ ],
+ hdrs = [
+ "reference/portable_tensor_utils.h",
+ ],
+ deps = [
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/kernels:activation_functor",
+ "//tensorflow/contrib/lite/kernels:op_macros",
+ ],
+)
+
+cc_library(
+ name = "neon_tensor_utils",
+ srcs = [
+ "optimized/neon_tensor_utils.cc",
+ ],
+ hdrs = [
+ "optimized/neon_tensor_utils.h",
+ "optimized/tensor_utils_impl.h",
+ ],
+ copts = NEON_FLAGS_IF_APPLICABLE,
+ deps = [
+ ":cpu_check",
+ ":portable_tensor_utils",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/kernels:activation_functor",
+ ],
+)
+
+cc_library(
+ name = "tensor_utils",
+ srcs = [
+ "tensor_utils.cc",
+ ],
+ hdrs = [
+ "optimized/tensor_utils_impl.h",
+ "reference/portable_tensor_utils.h",
+ "tensor_utils.h",
+ ],
+ copts = NEON_FLAGS_IF_APPLICABLE,
+ deps = [
+ "//tensorflow/contrib/lite/kernels:activation_functor",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ ] + select({
+ ":arm": [
+ ":neon_tensor_utils",
+ ],
+ ":arm64-v8a": [
+ ":neon_tensor_utils",
+ ],
+ ":armeabi-v7a": [
+ ":neon_tensor_utils",
+ ],
+ ":armv7a": [
+ ":neon_tensor_utils",
+ ],
+ ":ios_armv7": [
+ ":neon_tensor_utils",
+ ],
+ ":ios_arm64": [
+ ":neon_tensor_utils",
+ ],
+ "//conditions:default": [
+ ":portable_tensor_utils",
+ ],
+ }),
+)
+
+cc_test(
+ name = "tensor_utils_test",
+ srcs = ["tensor_utils_test.cc"],
+ copts = NEON_FLAGS_IF_APPLICABLE,
+ linkopts = select({
+ "//tensorflow:android": [
+ "-fPIE -pie",
+ ],
+ "//conditions:default": [],
+ }),
+ linkstatic = 1,
+ deps = [
+ ":tensor_utils",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "cpu_check",
+ hdrs = [
+ "optimized/cpu_check.h",
+ ],
+ deps = [
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "@androidndk//:cpufeatures",
+ ],
+ "//conditions:default": [],
+ },
+ ),
+)
+
+exports_files(["optimized/eigen_tensor_reduced_instantiations_oss.h"])
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h
new file mode 100644
index 0000000000..28f19a2506
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/common.h
@@ -0,0 +1,107 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_
+
+#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+#ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
+#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+#endif
+#endif
+
+#ifndef USE_NEON
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+#define USE_NEON
+#include <arm_neon.h>
+#endif
+
+#if defined __GNUC__ && defined __SSE4_1__
+#define USE_NEON
+
+#define OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+#pragma GCC diagnostic ignored "-Wattributes"
+
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wnarrowing"
+#pragma GCC diagnostic ignored "-Wsequence-point"
+
+#include "NEON_2_SSE.h"
+
+#pragma GCC diagnostic pop
+#endif
+#endif
+
+#include "public/gemmlowp.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+inline void GetActivationMinMax(FusedActivationFunctionType ac,
+ float* output_activation_min,
+ float* output_activation_max) {
+ switch (ac) {
+ case FusedActivationFunctionType::kNone:
+ *output_activation_min = std::numeric_limits<float>::lowest();
+ *output_activation_max = std::numeric_limits<float>::max();
+ break;
+ case FusedActivationFunctionType::kRelu:
+ *output_activation_min = 0.f;
+ *output_activation_max = std::numeric_limits<float>::max();
+ break;
+ case FusedActivationFunctionType::kRelu1:
+ *output_activation_min = -1.f;
+ *output_activation_max = 1.f;
+ break;
+ case FusedActivationFunctionType::kRelu6:
+ *output_activation_min = 0.f;
+ *output_activation_max = 6.f;
+ break;
+ }
+}
+
+inline float ActivationFunctionWithMinMax(float x, float output_activation_min,
+ float output_activation_max) {
+ return std::min(std::max(x, output_activation_min), output_activation_max);
+}
+
+// Legacy function, left for compatibility only.
+template <FusedActivationFunctionType Ac>
+float ActivationFunction(float x) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ return ActivationFunctionWithMinMax(x, output_activation_min,
+ output_activation_max);
+}
+
+inline int32 MultiplyByQuantizedMultiplierSmallerThanOne(
+ int32 x, int32 quantized_multiplier, int right_shift) {
+ using gemmlowp::RoundingDivideByPOT;
+ using gemmlowp::SaturatingRoundingDoublingHighMul;
+ return RoundingDivideByPOT(
+ SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift);
+}
+
+inline int32 MultiplyByQuantizedMultiplierGreaterThanOne(
+ int32 x, int32 quantized_multiplier, int left_shift) {
+ using gemmlowp::SaturatingRoundingDoublingHighMul;
+ return SaturatingRoundingDoublingHighMul(x * (1 << left_shift),
+ quantized_multiplier);
+}
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/compatibility.h b/tensorflow/contrib/lite/kernels/internal/compatibility.h
new file mode 100644
index 0000000000..796a03566a
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/compatibility.h
@@ -0,0 +1,78 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_
+
+#include <cassert>
+#include <cstdint>
+#include <cstdlib>
+
+#ifndef TFLITE_DCHECK
+#define TFLITE_DCHECK(condition) (condition) ? (void)0 : assert(false)
+#endif
+
+#ifndef TFLITE_DCHECK_EQ
+#define TFLITE_DCHECK_EQ(x, y) ((x) == (y)) ? (void)0 : assert(false)
+#endif
+
+#ifndef TFLITE_DCHECK_GE
+#define TFLITE_DCHECK_GE(x, y) ((x) >= (y)) ? (void)0 : assert(false)
+#endif
+
+#ifndef TFLITE_DCHECK_GT
+#define TFLITE_DCHECK_GT(x, y) ((x) > (y)) ? (void)0 : assert(false)
+#endif
+
+#ifndef TFLITE_DCHECK_LE
+#define TFLITE_DCHECK_LE(x, y) ((x) <= (y)) ? (void)0 : assert(false)
+#endif
+
+#ifndef TFLITE_DCHECK_LT
+#define TFLITE_DCHECK_LT(x, y) ((x) < (y)) ? (void)0 : assert(false)
+#endif
+
+// TODO(ahentz): Clean up: We should stick to the DCHECK versions.
+#ifndef TFLITE_CHECK
+#define TFLITE_CHECK(condition) (condition) ? (void)0 : abort()
+#endif
+
+#ifndef TFLITE_CHECK_EQ
+#define TFLITE_CHECK_EQ(x, y) ((x) == (y)) ? (void)0 : abort()
+#endif
+
+#ifndef TFLITE_CHECK_GE
+#define TFLITE_CHECK_GE(x, y) ((x) >= (y)) ? (void)0 : abort()
+#endif
+
+#ifndef TFLITE_CHECK_GT
+#define TFLITE_CHECK_GT(x, y) ((x) > (y)) ? (void)0 : abort()
+#endif
+
+#ifndef TFLITE_CHECK_LE
+#define TFLITE_CHECK_LE(x, y) ((x) <= (y)) ? (void)0 : abort()
+#endif
+
+#ifndef TFLITE_CHECK_LT
+#define TFLITE_CHECK_LT(x, y) ((x) < (y)) ? (void)0 : abort()
+#endif
+
+// TODO(ahentz): Clean up.
+using uint8 = std::uint8_t;
+using int16 = std::int16_t;
+using uint16 = std::uint16_t;
+using int32 = std::int32_t;
+using uint32 = std::uint32_t;
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h
new file mode 100644
index 0000000000..dea46cc120
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h
@@ -0,0 +1,65 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_
+
+namespace tflite {
+
+#ifdef __ANDROID__
+#include "ndk/sources/android/cpufeatures/cpu-features.h"
+
+// Runtime check for Neon support on Android.
+inline bool TestCPUFeatureNeon() {
+#ifdef __aarch64__
+ // ARM-64 always has NEON support.
+ return true;
+#else
+ static bool kUseAndroidNeon =
+ (android_getCpuFamily() == ANDROID_CPU_FAMILY_ARM &&
+ android_getCpuFeatures() & ANDROID_CPU_ARM_FEATURE_ARMv7 &&
+ android_getCpuFeatures() & ANDROID_CPU_ARM_FEATURE_NEON);
+ return kUseAndroidNeon;
+#endif // __aarch64__
+}
+
+#elif __ARM_NEON
+
+inline bool TestCPUFeatureNeon() {
+ return true;
+}
+
+#else
+
+inline bool TestCPUFeatureNeon() {
+ return false;
+}
+
+#endif
+
+} // namespace tflite
+
+// NEON_OR_PORTABLE(SomeFunc, arcs) calls NeonSomeFunc(args) if Neon is both
+// enabled at build time and detected at runtime, or PortableSomeFunc(args)
+// otherwise.
+#ifdef __ARM_ARCH_5TE__
+// Neon isn't available at all on ARMv5.
+#define NEON_OR_PORTABLE(funcname, ...) Portable##funcname(__VA_ARGS__)
+#else
+#define NEON_OR_PORTABLE(funcname, ...) \
+ TestCPUFeatureNeon() ? Neon##funcname(__VA_ARGS__) \
+ : Portable##funcname(__VA_ARGS__)
+#endif
+
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
new file mode 100644
index 0000000000..974611f52a
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
@@ -0,0 +1,987 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_
+
+#include "public/gemmlowp.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace optimized_ops {
+
+// Implementation of float DepthwiseConv
+
+template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
+struct FloatDepthwiseConvKernel {};
+
+#ifdef USE_NEON
+
+template <>
+struct FloatDepthwiseConvKernel<false, 8, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const float* input_ptr, int input_ptr_increment,
+ const float* filter_ptr, float* acc_buffer_ptr) {
+ // Load the filters
+ float32x4_t filter[2];
+ for (int i = 0; i < 2; i++) {
+ filter[i] = vld1q_f32(filter_ptr + 4 * i);
+ }
+ int outp = 0;
+ // Handle 2 output pixels at a time.
+ for (; outp <= num_output_pixels - 2; outp += 2) {
+ // Load the inputs
+ float32x4_t input[4];
+ for (int i = 0; i < 4; i++) {
+ input[i] = vld1q_f32(input_ptr + 4 * i);
+ }
+ input_ptr += 16;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ acc[0] = vmlaq_f32(acc[0], input[0], filter[0]);
+ acc[1] = vmlaq_f32(acc[1], input[1], filter[1]);
+ acc[2] = vmlaq_f32(acc[2], input[2], filter[0]);
+ acc[3] = vmlaq_f32(acc[3], input[3], filter[1]);
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ // Handle one output pixel at a time.
+ for (; outp < num_output_pixels; outp++) {
+ // Load the inputs
+ float32x4_t input[2];
+ for (int i = 0; i < 2; i++) {
+ input[i] = vld1q_f32(input_ptr + 4 * i);
+ }
+ input_ptr += 8;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vmlaq_f32(acc[i], input[i], filter[i]);
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 2; i++) {
+ vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 8;
+ }
+ }
+};
+
+template <>
+struct FloatDepthwiseConvKernel<false, 2, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const float* input_ptr, int input_ptr_increment,
+ const float* filter_ptr, float* acc_buffer_ptr) {
+ const float32x2_t filters = vld1_f32(filter_ptr);
+ const float32x4_t filters_dup2 = vcombine_f32(filters, filters);
+ int outp = 0;
+ // Handle 8 output pixels at a time.
+ for (; outp <= num_output_pixels - 8; outp += 8) {
+ // Load the inputs
+ float32x4_t input[4];
+ for (int i = 0; i < 4; i++) {
+ input[i] = vld1q_f32(input_ptr + 4 * i);
+ }
+ input_ptr += 16;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2);
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ // Handle 4 output pixels at a time.
+ for (; outp <= num_output_pixels - 4; outp += 4) {
+ // Load the inputs
+ float32x4_t input[2];
+ for (int i = 0; i < 2; i++) {
+ input[i] = vld1q_f32(input_ptr + 4 * i);
+ }
+ input_ptr += 8;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2);
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 2; i++) {
+ vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 8;
+ }
+ // Handle 2 output pixels at a time.
+ for (; outp <= num_output_pixels - 2; outp += 2) {
+ // Load the inputs
+ const float32x4_t input = vld1q_f32(input_ptr);
+ input_ptr += 4;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc = vld1q_f32(acc_buffer_ptr);
+ // Multiply-accumulate
+ acc = vmlaq_f32(acc, input, filters_dup2);
+ // Store the accumulators back to acc_buffer
+ vst1q_f32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 4;
+ }
+ // Handle 1 output pixel at a time
+ for (; outp < num_output_pixels; outp++) {
+ // Load the inputs
+ const float32x2_t input = vld1_f32(input_ptr);
+ input_ptr += 2;
+ // Load the accumulators from acc_buffer
+ float32x2_t acc = vld1_f32(acc_buffer_ptr);
+ // Multiply-accumulate
+ acc = vmla_f32(acc, input, filters);
+ // Store the accumulators back to acc_buffer
+ vst1_f32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 2;
+ }
+ }
+};
+
+template <>
+struct FloatDepthwiseConvKernel<true, 0, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const float* input_ptr, int input_ptr_increment,
+ const float* filter_ptr, float* acc_buffer_ptr) {
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ const float* local_filter_ptr = filter_ptr;
+ const float* local_input_ptr = input_ptr;
+ int ic = 0;
+ // Handle 16 input channels at a time.
+ for (; ic <= input_depth - 16; ic += 16) {
+ // Load the filters
+ float32x4_t filter_0 = vld1q_f32(local_filter_ptr + 4 * 0);
+ float32x4_t filter_1 = vld1q_f32(local_filter_ptr + 4 * 1);
+ float32x4_t filter_2 = vld1q_f32(local_filter_ptr + 4 * 2);
+ float32x4_t filter_3 = vld1q_f32(local_filter_ptr + 4 * 3);
+ local_filter_ptr += 16;
+ // Load the inputs
+ float32x4_t input_0 = vld1q_f32(local_input_ptr + 4 * 0);
+ float32x4_t input_1 = vld1q_f32(local_input_ptr + 4 * 1);
+ float32x4_t input_2 = vld1q_f32(local_input_ptr + 4 * 2);
+ float32x4_t input_3 = vld1q_f32(local_input_ptr + 4 * 3);
+ local_input_ptr += 16;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc_0 = vld1q_f32(acc_buffer_ptr + 4 * 0);
+ float32x4_t acc_1 = vld1q_f32(acc_buffer_ptr + 4 * 1);
+ float32x4_t acc_2 = vld1q_f32(acc_buffer_ptr + 4 * 2);
+ float32x4_t acc_3 = vld1q_f32(acc_buffer_ptr + 4 * 3);
+ // Multiply-accumulate
+ acc_0 = vmlaq_f32(acc_0, input_0, filter_0);
+ acc_1 = vmlaq_f32(acc_1, input_1, filter_1);
+ acc_2 = vmlaq_f32(acc_2, input_2, filter_2);
+ acc_3 = vmlaq_f32(acc_3, input_3, filter_3);
+ // Store the accumulators back to acc_buffer
+ vst1q_f32(acc_buffer_ptr + 4 * 0, acc_0);
+ vst1q_f32(acc_buffer_ptr + 4 * 1, acc_1);
+ vst1q_f32(acc_buffer_ptr + 4 * 2, acc_2);
+ vst1q_f32(acc_buffer_ptr + 4 * 3, acc_3);
+ acc_buffer_ptr += 16;
+ }
+ // Handle 4 input channels at a time.
+ for (; ic <= input_depth - 4; ic += 4) {
+ // Load the filters
+ float32x4_t filter;
+ filter = vld1q_f32(local_filter_ptr);
+ local_filter_ptr += 4;
+ // Load the inputs
+ float32x4_t input;
+ input = vld1q_f32(local_input_ptr);
+ local_input_ptr += 4;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc;
+ acc = vld1q_f32(acc_buffer_ptr);
+ // Multiply-accumulate
+ acc = vmlaq_f32(acc, input, filter);
+ // Store the accumulators back to acc_buffer
+ vst1q_f32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 4;
+ }
+ // Handle one input channel at a time.
+ for (; ic < input_depth; ic++) {
+ const float input_val = *local_input_ptr++;
+ const float filter_val = *local_filter_ptr++;
+ *acc_buffer_ptr++ += filter_val * input_val;
+ }
+ input_ptr += input_ptr_increment;
+ }
+ }
+};
+
+template <>
+struct FloatDepthwiseConvKernel<true, 0, 8> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const float* input_ptr, int input_ptr_increment,
+ const float* filter_ptr, float* acc_buffer_ptr) {
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ const float* local_filter_ptr = filter_ptr;
+ const float* local_input_ptr = input_ptr;
+ int ic = 0;
+ // Handle 2 input channels at a time.
+ for (; ic <= input_depth - 2; ic += 2) {
+ // Load the filters
+ float32x4_t filter[4];
+ for (int i = 0; i < 4; i++) {
+ filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
+ }
+ local_filter_ptr += 16;
+ // Load the inputs
+ const float32x2_t input = vld1_f32(local_input_ptr);
+ local_input_ptr += 2;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ acc[0] = vmlaq_lane_f32(acc[0], filter[0], input, 0);
+ acc[1] = vmlaq_lane_f32(acc[1], filter[1], input, 0);
+ acc[2] = vmlaq_lane_f32(acc[2], filter[2], input, 1);
+ acc[3] = vmlaq_lane_f32(acc[3], filter[3], input, 1);
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ // Handle one input channel at a time.
+ for (; ic < input_depth; ic++) {
+ // Load the filters
+ float32x4_t filter[2];
+ for (int i = 0; i < 2; i++) {
+ filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
+ }
+ local_filter_ptr += 8;
+ // Load the inputs
+ const float input_val = *local_input_ptr++;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val);
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 2; i++) {
+ vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 8;
+ }
+ input_ptr += input_ptr_increment;
+ }
+ }
+};
+
+template <>
+struct FloatDepthwiseConvKernel<true, 0, 2> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const float* input_ptr, int input_ptr_increment,
+ const float* filter_ptr, float* acc_buffer_ptr) {
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ const float* local_filter_ptr = filter_ptr;
+ const float* local_input_ptr = input_ptr;
+ int ic = 0;
+ // Handle 8 input channels at a time.
+ for (; ic <= input_depth - 8; ic += 8) {
+ // Load the filters
+ float32x4_t filter[4];
+ for (int i = 0; i < 4; i++) {
+ filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
+ }
+ local_filter_ptr += 16;
+ // Load the inputs
+ float32x4x2_t input_dup2[2];
+ for (int i = 0; i < 2; i++) {
+ const float32x4_t input = vld1q_f32(local_input_ptr + 4 * i);
+ input_dup2[i] = vzipq_f32(input, input);
+ }
+ local_input_ptr += 8;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ acc[0] = vmlaq_f32(acc[0], filter[0], input_dup2[0].val[0]);
+ acc[1] = vmlaq_f32(acc[1], filter[1], input_dup2[0].val[1]);
+ acc[2] = vmlaq_f32(acc[2], filter[2], input_dup2[1].val[0]);
+ acc[3] = vmlaq_f32(acc[3], filter[3], input_dup2[1].val[1]);
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ // Handle 4 input channels at a time.
+ for (; ic <= input_depth - 4; ic += 4) {
+ // Load the filters
+ float32x2_t filter[4];
+ for (int i = 0; i < 4; i++) {
+ filter[i] = vld1_f32(local_filter_ptr + 2 * i);
+ }
+ local_filter_ptr += 8;
+ // Load the inputs
+ const float32x4_t input = vld1q_f32(local_input_ptr);
+ local_input_ptr += 4;
+ // Load the accumulators from acc_buffer
+ float32x2_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1_f32(acc_buffer_ptr + 2 * i);
+ }
+ // Multiply-accumulate
+ acc[0] = vmla_lane_f32(acc[0], filter[0], vget_low_f32(input), 0);
+ acc[1] = vmla_lane_f32(acc[1], filter[1], vget_low_f32(input), 1);
+ acc[2] = vmla_lane_f32(acc[2], filter[2], vget_high_f32(input), 0);
+ acc[3] = vmla_lane_f32(acc[3], filter[3], vget_high_f32(input), 1);
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1_f32(acc_buffer_ptr + 2 * i, acc[i]);
+ }
+ acc_buffer_ptr += 8;
+ }
+ // Handle 2 input channels at a time.
+ for (; ic <= input_depth - 2; ic += 2) {
+ // Load the filters
+ const float32x4_t filter = vld1q_f32(local_filter_ptr);
+ local_filter_ptr += 4;
+ // Load the inputs
+ const float32x2_t input = vld1_f32(local_input_ptr);
+ local_input_ptr += 2;
+ // Load the accumulators from acc_buffer
+ float32x2_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vld1_f32(acc_buffer_ptr + 2 * i);
+ }
+ // Multiply-accumulate
+ acc[0] = vmla_lane_f32(acc[0], vget_low_f32(filter), input, 0);
+ acc[1] = vmla_lane_f32(acc[1], vget_high_f32(filter), input, 1);
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 2; i++) {
+ vst1_f32(acc_buffer_ptr + 2 * i, acc[i]);
+ }
+ acc_buffer_ptr += 4;
+ }
+ // Handle one input channel at a time.
+ for (; ic < input_depth; ic++) {
+ // Load the inputs
+ const float input_val = *local_input_ptr++;
+ // Multiply-accumulate
+ for (int i = 0; i < 2; i++) {
+ acc_buffer_ptr[i] += local_filter_ptr[i] * input_val;
+ }
+ local_filter_ptr += 2;
+ acc_buffer_ptr += 2;
+ }
+ input_ptr += input_ptr_increment;
+ }
+ }
+};
+
+template <>
+struct FloatDepthwiseConvKernel<true, 1, 8> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const float* input_ptr, int input_ptr_increment,
+ const float* filter_ptr, float* acc_buffer_ptr) {
+ // Load the filters
+ float32x4_t filter[2];
+ for (int i = 0; i < 2; i++) {
+ filter[i] = vld1q_f32(filter_ptr + 4 * i);
+ }
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ // Load the inputs
+ const float input_val = *input_ptr;
+ input_ptr += input_ptr_increment;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val);
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 2; i++) {
+ vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 8;
+ }
+ }
+};
+
+template <>
+struct FloatDepthwiseConvKernel<true, 1, 32> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const float* input_ptr, int input_ptr_increment,
+ const float* filter_ptr, float* acc_buffer_ptr) {
+ // Load the filters
+ float32x4_t filter_0 = vld1q_f32(filter_ptr + 4 * 0);
+ float32x4_t filter_1 = vld1q_f32(filter_ptr + 4 * 1);
+ float32x4_t filter_2 = vld1q_f32(filter_ptr + 4 * 2);
+ float32x4_t filter_3 = vld1q_f32(filter_ptr + 4 * 3);
+ float32x4_t filter_4 = vld1q_f32(filter_ptr + 4 * 4);
+ float32x4_t filter_5 = vld1q_f32(filter_ptr + 4 * 5);
+ float32x4_t filter_6 = vld1q_f32(filter_ptr + 4 * 6);
+ float32x4_t filter_7 = vld1q_f32(filter_ptr + 4 * 7);
+
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ // Load the inputs
+ const float input_val = *input_ptr;
+ input_ptr += input_ptr_increment;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc_0 = vld1q_f32(acc_buffer_ptr + 4 * 0);
+ float32x4_t acc_1 = vld1q_f32(acc_buffer_ptr + 4 * 1);
+ float32x4_t acc_2 = vld1q_f32(acc_buffer_ptr + 4 * 2);
+ float32x4_t acc_3 = vld1q_f32(acc_buffer_ptr + 4 * 3);
+ float32x4_t acc_4 = vld1q_f32(acc_buffer_ptr + 4 * 4);
+ float32x4_t acc_5 = vld1q_f32(acc_buffer_ptr + 4 * 5);
+ float32x4_t acc_6 = vld1q_f32(acc_buffer_ptr + 4 * 6);
+ float32x4_t acc_7 = vld1q_f32(acc_buffer_ptr + 4 * 7);
+ // Multiply-accumulate
+ acc_0 = vmlaq_n_f32(acc_0, filter_0, input_val);
+ acc_1 = vmlaq_n_f32(acc_1, filter_1, input_val);
+ acc_2 = vmlaq_n_f32(acc_2, filter_2, input_val);
+ acc_3 = vmlaq_n_f32(acc_3, filter_3, input_val);
+ acc_4 = vmlaq_n_f32(acc_4, filter_4, input_val);
+ acc_5 = vmlaq_n_f32(acc_5, filter_5, input_val);
+ acc_6 = vmlaq_n_f32(acc_6, filter_6, input_val);
+ acc_7 = vmlaq_n_f32(acc_7, filter_7, input_val);
+ // Store the accumulators back to acc_buffer
+ vst1q_f32(acc_buffer_ptr + 4 * 0, acc_0);
+ vst1q_f32(acc_buffer_ptr + 4 * 1, acc_1);
+ vst1q_f32(acc_buffer_ptr + 4 * 2, acc_2);
+ vst1q_f32(acc_buffer_ptr + 4 * 3, acc_3);
+ vst1q_f32(acc_buffer_ptr + 4 * 4, acc_4);
+ vst1q_f32(acc_buffer_ptr + 4 * 5, acc_5);
+ vst1q_f32(acc_buffer_ptr + 4 * 6, acc_6);
+ vst1q_f32(acc_buffer_ptr + 4 * 7, acc_7);
+ acc_buffer_ptr += 32;
+ }
+ }
+};
+
+template <>
+struct FloatDepthwiseConvKernel<true, 0, 16> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const float* input_ptr, int input_ptr_increment,
+ const float* filter_ptr, float* acc_buffer_ptr) {
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ const float* local_filter_ptr = filter_ptr;
+ const float* local_input_ptr = input_ptr;
+ for (int ic = 0; ic < input_depth; ic++) {
+ // Load the filters
+ float32x4_t filter[4];
+ for (int i = 0; i < 4; i++) {
+ filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
+ }
+ local_filter_ptr += 16;
+ // Load the inputs
+ const float input_val = *local_input_ptr++;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val);
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ input_ptr += input_ptr_increment;
+ }
+ }
+};
+
+template <>
+struct FloatDepthwiseConvKernel<true, 8, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const float* input_ptr, int input_ptr_increment,
+ const float* filter_ptr, float* acc_buffer_ptr) {
+ // Load the filters
+ float32x4_t filter[2];
+ for (int i = 0; i < 2; i++) {
+ filter[i] = vld1q_f32(filter_ptr + 4 * i);
+ }
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ // Load the inputs
+ float32x4_t input[2];
+ for (int i = 0; i < 2; i++) {
+ input[i] = vld1q_f32(input_ptr + 4 * i);
+ }
+ // Load the accumulators from acc_buffer
+ float32x4_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vmlaq_f32(acc[i], input[i], filter[i]);
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 2; i++) {
+ vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 8;
+ input_ptr += input_ptr_increment;
+ }
+ }
+};
+
+template <>
+struct FloatDepthwiseConvKernel<true, 2, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const float* input_ptr, int input_ptr_increment,
+ const float* filter_ptr, float* acc_buffer_ptr) {
+ float32x2_t filter = vld1_f32(filter_ptr);
+ float32x4_t filter_x4 = vcombine_f32(filter, filter);
+ int outp = 0;
+
+ // Handle two output pixels at a time.
+ for (; outp <= num_output_pixels - 2; outp += 2) {
+ // Load the inputs
+ float32x2_t input_1 = vld1_f32(input_ptr);
+ input_ptr += input_ptr_increment;
+ float32x2_t input_2 = vld1_f32(input_ptr);
+ input_ptr += input_ptr_increment;
+ float32x4_t input = vcombine_f32(input_1, input_2);
+
+ // Load the accumulators from acc_buffer
+ float32x4_t acc = vld1q_f32(acc_buffer_ptr);
+
+ // Multiply-accumulate
+ acc = vmlaq_f32(acc, input, filter_x4);
+
+ // Store the accumulators back to acc_buffer
+ vst1q_f32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 4;
+ }
+ // Handle one output pixel at a time.
+ for (; outp < num_output_pixels; outp++) {
+ // Load the inputs
+ float32x2_t input = vld1_f32(input_ptr);
+ input_ptr += input_ptr_increment;
+
+ // Load the accumulators from acc_buffer
+ float32x2_t acc = vld1_f32(acc_buffer_ptr);
+
+ // Multiply-accumulate
+ acc = vmla_f32(acc, input, filter);
+
+ // Store the accumulators back to acc_buffer
+ vst1_f32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 2;
+ }
+ }
+};
+
+template <>
+struct FloatDepthwiseConvKernel<true, 4, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const float* input_ptr, int input_ptr_increment,
+ const float* filter_ptr, float* acc_buffer_ptr) {
+ float32x4_t filter = vld1q_f32(filter_ptr);
+
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ // Load the inputs
+ float32x4_t input = vld1q_f32(input_ptr);
+ // Load the accumulators from acc_buffer
+ float32x4_t acc = vld1q_f32(acc_buffer_ptr);
+ // Multiply-accumulate
+ acc = vmlaq_f32(acc, input, filter);
+ // Store the accumulators back to acc_buffer
+ vst1q_f32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 4;
+ input_ptr += input_ptr_increment;
+ }
+ }
+};
+#endif
+
+// Accumulates the effect of one row of the filter, on a segment of one row
+// of the output, accessing the corresponding one row of the input.
+template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
+void FloatDepthwiseConvAccumRow(int stride, int input_depth, int input_width,
+ const float* input_data, int pad_width,
+ int depth_multiplier, int filter_width,
+ const float* filter_data,
+ int out_x_buffer_start, int out_x_buffer_end,
+ int output_depth, float* acc_buffer) {
+#ifdef GEMMLOWP_PROFILING
+ gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__);
+#endif
+ // Sanity check parameters. This is important in particular to ensure
+ // that we keep the number of template instantiations minimal, so we don't
+ // increase binary size unnecessarily.
+ static_assert(kFixedDepthMultiplier || !kFixedInputDepth, "");
+ static_assert(kFixedInputDepth || kAllowStrided, "");
+ TFLITE_DCHECK(stride == 1 || kAllowStrided);
+ if (kFixedInputDepth) {
+ TFLITE_DCHECK_EQ(input_depth, kFixedInputDepth);
+ }
+ if (kFixedDepthMultiplier) {
+ TFLITE_DCHECK_EQ(depth_multiplier, kFixedDepthMultiplier);
+ }
+ TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
+ const int input_ptr_increment = stride * input_depth;
+ const float* filter_base_ptr = filter_data;
+ for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+ // For the current (filter_x, filter_y) point in the filter,
+ // compute the boundaries of the corresponding output row segment.
+ int out_x_loop_start_unclampled = 0;
+ int out_x_loop_end_unclampled = 0;
+ if (kAllowStrided) {
+ if (stride == 2) {
+ out_x_loop_start_unclampled = (pad_width - filter_x + 1) / 2;
+ out_x_loop_end_unclampled =
+ (pad_width + input_width - filter_x + 1) / 2;
+ } else if (stride == 4) {
+ out_x_loop_start_unclampled = (pad_width - filter_x + 3) / 4;
+ out_x_loop_end_unclampled =
+ (pad_width + input_width - filter_x + 3) / 4;
+ } else {
+ out_x_loop_start_unclampled =
+ (pad_width - filter_x + stride - 1) / stride;
+ out_x_loop_end_unclampled =
+ (pad_width + input_width - filter_x + stride - 1) / stride;
+ }
+ } else {
+ out_x_loop_start_unclampled = pad_width - filter_x;
+ out_x_loop_end_unclampled = pad_width + input_width - filter_x;
+ }
+ // The kernel will have to iterate on the segment of the
+ // output row that starts at out_x_loop_start and out_x_loop_end.
+ const int out_x_loop_start =
+ std::max(out_x_buffer_start, out_x_loop_start_unclampled);
+ const int out_x_loop_end =
+ std::min(out_x_buffer_end, out_x_loop_end_unclampled);
+
+ float* acc_buffer_ptr =
+ acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
+ const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
+ const float* input_ptr = input_data + in_x_origin * input_depth;
+ const int num_output_pixels = out_x_loop_end - out_x_loop_start;
+ FloatDepthwiseConvKernel<kAllowStrided, kFixedInputDepth,
+ kFixedDepthMultiplier>::Run(num_output_pixels,
+ input_depth,
+ depth_multiplier,
+ input_ptr,
+ input_ptr_increment,
+ filter_base_ptr,
+ acc_buffer_ptr);
+ filter_base_ptr += output_depth;
+ }
+}
+
+// generic fallback of FloatDepthwiseConvAccumRow, portable, non-templatized.
+inline void FloatDepthwiseConvAccumRowGeneric(
+ int stride, int input_depth, int input_width, const float* input_data,
+ int pad_width, int depth_multiplier, int filter_width,
+ const float* filter_data, int out_x_buffer_start, int out_x_buffer_end,
+ int output_depth, float* acc_buffer) {
+ gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)");
+#ifdef TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+ LOG(FATAL)
+ << "\n\n"
+ << "*****************************************************************\n"
+ << "* This tfmini inference code was about to use the slow generic\n"
+ << "* fallback implementation for a DepthwiseConv op, and we want you\n"
+ << "* to be aware of that so that you will know why you get terrible\n"
+ << "* performance.\n"
+ << "*\n"
+ << "* If you would like to carry on with the slow code, compile\n"
+ << "* with this preprocessor token defined:\n"
+ << "* ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK.\n"
+ << "*\n"
+ << "* The right thing to do, if you care about performance, is to add\n"
+ << "* a new DepthwiseConv kernel to tfmini to cover your case.\n"
+ << "* The relevant parameters defining your case are:\n"
+ << "* stride = " << stride << "\n"
+ << "* input_depth = " << input_depth << "\n"
+ << "* depth_multiplier = " << depth_multiplier << "\n"
+ << "*\n"
+ << "* Please do not hesitate to contact benoitjacob@ with this\n"
+ << "* information.\n"
+ << "*****************************************************************\n";
+#endif // ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+#endif // TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+ const float* filter_base_ptr = filter_data;
+ for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+ const int out_x_loop_start = std::max(
+ out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride);
+ const int out_x_loop_end =
+ std::min(out_x_buffer_end,
+ (pad_width + input_width - filter_x + stride - 1) / stride);
+
+ float* acc_buffer_ptr =
+ acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
+ const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
+ const float* input_ptr = input_data + in_x_origin * input_depth;
+ const int input_ptr_increment = (stride - 1) * input_depth;
+ for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) {
+ const float* filter_ptr = filter_base_ptr;
+ for (int ic = 0; ic < input_depth; ++ic) {
+ const float input_val = *input_ptr++;
+ for (int m = 0; m < depth_multiplier; m++) {
+ const float filter_val = *filter_ptr++;
+ *acc_buffer_ptr++ += filter_val * input_val;
+ }
+ }
+ input_ptr += input_ptr_increment;
+ }
+ filter_base_ptr += output_depth;
+ }
+}
+
+// Initializes the accumulator buffer with bias values.
+inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth,
+ const float* bias_data,
+ float* acc_buffer) {
+ // TODO(benoitjacob): This might need optimized specializations
+ // for small output_depth values, if that ever becomes an important
+ // case (like it was for some quantized DepthwiseConv cases).
+ for (int i = 0; i < num_output_pixels; i++) {
+ memcpy(acc_buffer + i * output_depth, bias_data,
+ sizeof(acc_buffer[0]) * output_depth);
+ }
+}
+
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("DepthwiseConv");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int input_depth = ArraySize(input_dims, 0);
+ const int filter_height = ArraySize(filter_dims, 2);
+ const int filter_width = ArraySize(filter_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+
+ static const int kAccBufferMaxSize = 2048;
+ float acc_buffer[kAccBufferMaxSize];
+ TFLITE_DCHECK_GE(kAccBufferMaxSize, output_depth);
+ const int kOutputPixelsInAccBuffer = kAccBufferMaxSize / output_depth;
+ const int kAccBufferActualSize = kOutputPixelsInAccBuffer * output_depth;
+ TFLITE_DCHECK_LE(kOutputPixelsInAccBuffer * output_depth,
+ kAccBufferActualSize);
+ TFLITE_DCHECK_LE(kAccBufferActualSize, kAccBufferMaxSize);
+ TFLITE_DCHECK_GE(kOutputPixelsInAccBuffer, 1);
+
+ // row_accum_func will point to the core accumulation function to be used
+ // for this DepthwiseConv op.
+ using row_accum_func_t = decltype(&FloatDepthwiseConvAccumRowGeneric);
+ row_accum_func_t row_accum_func = nullptr;
+
+#define TFMINI_USE_DEPTHWISECONV_KERNEL(ALLOW_STRIDED, FIXED_INPUT_DEPTH, \
+ FIXED_DEPTH_MULTIPLIER) \
+ if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) && \
+ (input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) && \
+ depth_multiplier == FIXED_DEPTH_MULTIPLIER) { \
+ row_accum_func = \
+ FloatDepthwiseConvAccumRow<ALLOW_STRIDED, FIXED_INPUT_DEPTH, \
+ FIXED_DEPTH_MULTIPLIER>; \
+ }
+
+#ifdef USE_NEON
+ // We go over our list of kernels by decreasing order of preference
+ // for the cases where multiple kernels could apply.
+
+ // Start with the fastest kernels: AllowStrided=false, fixed input depth.
+
+ TFMINI_USE_DEPTHWISECONV_KERNEL(false, 8, 1)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 1)
+
+ // Next come the strided kernels: AllowStrided=true, fixed input depth.
+ // They are a bit less efficient, but allow stride!=1.
+
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 1)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 32)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 2, 1)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 4, 1)
+
+ // Finally, the kernels allowing a variable input depth,
+ // these are the least efficient but most general kernels.
+
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 1)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 2)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 8)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 16)
+
+#endif // USE_NEON
+
+#undef TFMINI_USE_DEPTHWISECONV_KERNEL
+
+ // No matching fast kernel found, use slow fallback.
+ if (!row_accum_func) {
+ row_accum_func = FloatDepthwiseConvAccumRowGeneric;
+ }
+
+ // Now that we have determined row_accum_func, we can start work.
+ float* output_ptr = output_data;
+ for (int b = 0; b < batches; ++b) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_end =
+ std::min(filter_height, input_height - in_y_origin);
+ for (int out_x_buffer_start = 0; out_x_buffer_start < output_width;
+ out_x_buffer_start += kOutputPixelsInAccBuffer) {
+ const int out_x_buffer_end = std::min(
+ output_width, out_x_buffer_start + kOutputPixelsInAccBuffer);
+ // We call a 'pixel' a group of activation that share all but the
+ // 'depth'/'channel' coordinate. num_output_pixels is the number of
+ // output pixels that we will accumulate in this loop iteration.
+ const int num_output_pixels = out_x_buffer_end - out_x_buffer_start;
+ // Initialize our local accumulator with the bias values, so we don't
+ // have to add them later.
+ DepthwiseConvInitAccBuffer(num_output_pixels, output_depth, bias_data,
+ acc_buffer);
+ // Accumulation loop. Most of the time should be spent in here.
+ for (int filter_y = filter_y_start; filter_y < filter_y_end;
+ ++filter_y) {
+ const int in_y = in_y_origin + filter_y;
+ row_accum_func(stride_width, input_depth, input_width,
+ input_data + in_y * input_dims.strides[2] +
+ b * input_dims.strides[3],
+ pad_width, depth_multiplier, filter_width,
+ filter_data + filter_y * filter_dims.strides[2],
+ out_x_buffer_start, out_x_buffer_end, output_depth,
+ acc_buffer);
+ }
+ // Finished accumulating. Now store to destination.
+ const int num_output_values = output_depth * num_output_pixels;
+ int i = 0;
+// TODO(benoitjacob) optimized code goes here
+#ifdef USE_NEON
+ // Handle 16 values at a time
+ for (; i <= num_output_values - 16; i += 16) {
+ float32x4_t acc[4];
+ for (int k = 0; k < 4; k++) {
+ acc[k] = vld1q_f32(acc_buffer + i + 4 * k);
+ }
+ for (int k = 0; k < 4; k++) {
+ acc[k] = vmaxq_f32(
+ vdupq_n_f32(output_activation_min),
+ vminq_f32(vdupq_n_f32(output_activation_max), acc[k]));
+ }
+ for (int k = 0; k < 4; k++) {
+ vst1q_f32(output_ptr + 4 * k, acc[k]);
+ }
+ output_ptr += 16;
+ }
+ // Handle 4 values at a time
+ for (; i <= num_output_values - 4; i += 4) {
+ float32x4_t acc = vld1q_f32(acc_buffer + i);
+
+ acc = vmaxq_f32(vdupq_n_f32(output_activation_min),
+ vminq_f32(vdupq_n_f32(output_activation_max), acc));
+
+ vst1q_f32(output_ptr, acc);
+ output_ptr += 4;
+ }
+#endif
+ // Handle leftover values, one by one. This is very slow.
+ for (; i < num_output_values; i++) {
+ float acc = acc_buffer[i];
+ acc = std::max(output_activation_min,
+ std::min(output_activation_max, acc));
+
+ *output_ptr++ = acc;
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride_width, stride_height, pad_width, pad_height,
+ depth_multiplier, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ float* output_data, const Dims<4>& output_dims) {
+ DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride, stride, pad_width, pad_height,
+ depth_multiplier, output_data, output_dims);
+}
+
+} // namespace optimized_ops
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
new file mode 100644
index 0000000000..051ed2a2c4
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
@@ -0,0 +1,1916 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_
+
+#include "fixedpoint/fixedpoint.h"
+#include "public/gemmlowp.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace optimized_ops {
+
+// Implementation of quantized DepthwiseConv
+
+template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
+struct QuantizedDepthwiseConvKernel {};
+
+#ifdef USE_NEON
+template <>
+struct QuantizedDepthwiseConvKernel<true, 8, 2> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ uint8x8x2_t filter_u8;
+ filter_u8.val[0] = vld1_u8(filter_ptr);
+ filter_u8.val[1] = vld1_u8(filter_ptr + 8);
+ int16x8_t filter[2];
+ for (int i = 0; i < 2; i++) {
+ filter[i] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i])),
+ vdupq_n_s16(filter_offset));
+ }
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ // Load the accumulators from acc_buffer
+ int32x4x2_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8);
+ }
+ // Load the inputs, add input_offset.
+ const uint8x8_t input_u8 = vld1_u8(input_ptr);
+ input_ptr += input_ptr_increment;
+ const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+ const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+ // Duplicate the input values, 2-fold
+ const int16x8x2_t input_dup2 = vzipq_s16(input, input);
+ // Multiply-accumulate
+ for (int i = 0; i < 2; i++) {
+ acc[0].val[i] = vmlal_s16(acc[0].val[i], vget_low_s16(filter[i]),
+ vget_low_s16(input_dup2.val[i]));
+ acc[1].val[i] = vmlal_s16(acc[1].val[i], vget_high_s16(filter[i]),
+ vget_high_s16(input_dup2.val[i]));
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 2; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]);
+ vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<false, 8, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ const uint8x8_t filter_u8 = vld1_u8(filter_ptr);
+ const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
+ const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
+
+ int outp = 0;
+ // Handle 2 output pixels at a time.
+ for (; outp <= num_output_pixels - 2; outp += 2) {
+ // Load the accumulators from acc_buffer.
+ int32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8[2];
+ for (int i = 0; i < 2; i++) {
+ input_u8[i] = vld1_u8(input_ptr + 8 * i);
+ }
+ input_ptr += 16;
+ int16x8_t input[2];
+ for (int i = 0; i < 2; i++) {
+ input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i]));
+ }
+ for (int i = 0; i < 2; i++) {
+ input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset));
+ }
+ // Multiply-accumulate.
+ acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), vget_low_s16(input[0]));
+ acc[1] =
+ vmlal_s16(acc[1], vget_high_s16(filter), vget_high_s16(input[0]));
+ acc[2] = vmlal_s16(acc[2], vget_low_s16(filter), vget_low_s16(input[1]));
+ acc[3] =
+ vmlal_s16(acc[3], vget_high_s16(filter), vget_high_s16(input[1]));
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ // Handle 1 output pixel at a time.
+ for (; outp < num_output_pixels; outp++) {
+ // Load the accumulators from acc_buffer.
+ int32x4_t acc[2];
+ acc[0] = vld1q_s32(acc_buffer_ptr);
+ acc[1] = vld1q_s32(acc_buffer_ptr + 4);
+
+ // Load the inputs, add input_offset.
+ const uint8x8_t input_u8 = vld1_u8(input_ptr);
+ input_ptr += 8;
+ const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+ const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+ // Multiply-accumulate.
+ acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), vget_low_s16(input));
+ acc[1] = vmlal_s16(acc[1], vget_high_s16(filter), vget_high_s16(input));
+ // Store the accumulators back to acc_buffer
+ vst1q_s32(acc_buffer_ptr, acc[0]);
+ vst1q_s32(acc_buffer_ptr + 4, acc[1]);
+ acc_buffer_ptr += 8;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<false, 4, 2> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ const uint8x8_t filter_u8 = vld1_u8(filter_ptr);
+ const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
+ const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
+
+ int outp = 0;
+ // Handle 2 output pixels at a time.
+ for (; outp <= num_output_pixels - 2; outp += 2) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Load the inputs, add input_offset.
+ const uint8x8_t input_u8 = vld1_u8(input_ptr);
+ input_ptr += 8;
+ const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+ const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+ // Duplicate the input values, 2-fold
+ const int16x8x2_t input_dup2 = vzipq_s16(input, input);
+ // Multiply-accumulate
+ for (int i = 0; i < 2; i++) {
+ acc[2 * i + 0] = vmlal_s16(acc[2 * i + 0], vget_low_s16(filter),
+ vget_low_s16(input_dup2.val[i]));
+ acc[2 * i + 1] = vmlal_s16(acc[2 * i + 1], vget_high_s16(filter),
+ vget_high_s16(input_dup2.val[i]));
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ // Handle one output pixel at a time.
+ for (; outp < num_output_pixels; outp++) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vdup_n_u8(0);
+ input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
+ input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
+ input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
+ input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
+ input_ptr += 4;
+ const int16x4_t input_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+ // Duplicate the input values, 2-fold
+ const int16x4x2_t input_dup2 = vzip_s16(input, input);
+ // Multiply-accumulate
+ acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), input_dup2.val[0]);
+ acc[1] = vmlal_s16(acc[1], vget_high_s16(filter), input_dup2.val[1]);
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 2; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 8;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<false, 2, 8> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ int16x8_t filter[2];
+ for (int i = 0; i < 2; i++) {
+ const uint8x8_t filter_u8 = vld1_u8(filter_ptr + 8 * i);
+ const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
+ filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
+ }
+ int outp = 0;
+ // Handle two output pixels at a time.
+ for (; outp <= num_output_pixels - 2; outp += 2) {
+ // Load the accumulators from acc_buffer.
+ int32x4_t acc[8];
+ for (int i = 0; i < 8; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vdup_n_u8(0);
+ input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
+ input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
+ input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
+ input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
+ input_ptr += 4;
+ const int16x4_t input_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+ // Multiply-accumulate.
+ acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0);
+ acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 0);
+ acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 1);
+ acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 1);
+ acc[4] = vmlal_lane_s16(acc[4], vget_low_s16(filter[0]), input, 2);
+ acc[5] = vmlal_lane_s16(acc[5], vget_high_s16(filter[0]), input, 2);
+ acc[6] = vmlal_lane_s16(acc[6], vget_low_s16(filter[1]), input, 3);
+ acc[7] = vmlal_lane_s16(acc[7], vget_high_s16(filter[1]), input, 3);
+ // Store the accumulators back to acc_buffer.
+ for (int i = 0; i < 8; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 32;
+ }
+ // Handle one output pixel at a time.
+ for (; outp < num_output_pixels; outp++) {
+ // Load the accumulators from acc_buffer.
+ int32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vdup_n_u8(0);
+ input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
+ input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
+ input_ptr += 2;
+ const int16x4_t input_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+
+ // Multiply-accumulate.
+ acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0);
+ acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 0);
+ acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 1);
+ acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 1);
+
+ // Store the accumulators back to acc_buffer.
+ for (int i = 0; i < 4; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<false, 2, 2> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ uint8x8_t filter_u8 = vdup_n_u8(0);
+ filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
+ filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
+ filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2);
+ filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3);
+ const int16x4_t filter_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
+ const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
+
+ int outp = 0;
+ // Handle 4 output pixels at a time.
+ for (; outp <= num_output_pixels - 4; outp += 4) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+
+ // Load the inputs, add input_offset.
+ const uint8x8_t input_u8 = vld1_u8(input_ptr);
+ input_ptr += 8;
+ const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+ const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+ // Duplicate the input values, 2-fold
+ const int16x8x2_t input_dup2 = vzipq_s16(input, input);
+ // Multiply-accumulate
+ acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input_dup2.val[0]));
+ acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input_dup2.val[0]));
+ acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input_dup2.val[1]));
+ acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input_dup2.val[1]));
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ // Handle one output pixel at a time.
+ for (; outp < num_output_pixels; outp++) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc = vld1q_s32(acc_buffer_ptr);
+
+ uint8x8_t input_u8 = vdup_n_u8(0);
+ input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
+ input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
+ input_ptr += 2;
+ const int16x4_t input_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+ // Duplicate the input values, 2-fold
+ const int16x4_t input_dup2 = vzip_s16(input, input).val[0];
+ // Multiply-accumulate
+ acc = vmlal_s16(acc, filter, input_dup2);
+ // Store the accumulators back to acc_buffer
+ vst1q_s32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 4;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<false, 2, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ uint8x8_t filter_u8 = vdup_n_u8(0);
+ filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
+ filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
+ filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2);
+ filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3);
+ const int16x4_t filter_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
+ const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
+
+ int outp = 0;
+ // Handle 8 output pixels at a time.
+ for (; outp <= num_output_pixels - 8; outp += 8) {
+ // Load the accumulators from acc_buffer.
+ int32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8[2];
+ for (int i = 0; i < 2; i++) {
+ input_u8[i] = vld1_u8(input_ptr + 8 * i);
+ }
+ input_ptr += 16;
+ int16x8_t input[2];
+ for (int i = 0; i < 2; i++) {
+ input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i]));
+ }
+ for (int i = 0; i < 2; i++) {
+ input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset));
+ }
+
+ // Multiply-accumulate.
+ acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input[0]));
+ acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input[0]));
+ acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input[1]));
+ acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input[1]));
+ // Store the accumulators back to acc_buffer.
+ for (int i = 0; i < 4; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ // Handle 4 output pixels at a time.
+ for (; outp <= num_output_pixels - 4; outp += 4) {
+ // Load the accumulators from acc_buffer.
+ int32x4_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Load the inputs, add input_offset.
+ const uint8x8_t input_u8 = vld1_u8(input_ptr);
+ input_ptr += 8;
+ const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+ const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+
+ // Multiply-accumulate.
+ acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input));
+ acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input));
+ // Store the accumulators back to acc_buffer.
+ for (int i = 0; i < 2; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 8;
+ }
+ // Handle 2 output pixels at a time.
+ for (; outp <= num_output_pixels - 2; outp += 2) {
+ // Load the accumulators from acc_buffer.
+ int32x4_t acc = vld1q_s32(acc_buffer_ptr);
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vdup_n_u8(0);
+ input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
+ input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
+ input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
+ input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
+ input_ptr += 4;
+ const int16x4_t input_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+
+ // Multiply-accumulate.
+ acc = vmlal_s16(acc, filter, input);
+ // Store the accumulators back to acc_buffer.
+ vst1q_s32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 4;
+ }
+ // Handle 1 output pixel at a time.
+ for (; outp < num_output_pixels; outp++) {
+ // Load the accumulators from acc_buffer.
+ int32x2_t acc = vld1_s32(acc_buffer_ptr);
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vdup_n_u8(0);
+ input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
+ input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
+ input_ptr += 2;
+ const int16x4_t input_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+
+ // Multiply-accumulate.
+ acc = vget_low_s32(vmlal_s16(vcombine_s32(acc, acc), filter, input));
+ // Store the accumulators back to acc_buffer.
+ vst1_s32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 2;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<false, 1, 2> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ uint8x8_t filter_u8 = vdup_n_u8(0);
+ filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
+ filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
+ filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2);
+ filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3);
+ const int16x4_t filter_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
+ const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
+
+ int outp = 0;
+ // Handle 8 output pixels at a time.
+ for (; outp <= num_output_pixels - 8; outp += 8) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+
+ // Load the inputs, add input_offset.
+ const uint8x8_t input_u8 = vld1_u8(input_ptr);
+ input_ptr += 8;
+ const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+ const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+ // Duplicate the input values, 2-fold
+ const int16x8x2_t input_dup2 = vzipq_s16(input, input);
+ // Multiply-accumulate
+ acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input_dup2.val[0]));
+ acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input_dup2.val[0]));
+ acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input_dup2.val[1]));
+ acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input_dup2.val[1]));
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ // Handle one output pixel at a time.
+ for (; outp < num_output_pixels; outp++) {
+ // Load the accumulators from acc_buffer
+ int32x2_t acc = vld1_s32(acc_buffer_ptr);
+
+ // Load the inputs, add input_offset.
+ const uint32 input = *input_ptr++ + input_offset;
+
+ // Multiply-accumulate
+ acc = vget_low_s32(vmlal_n_s16(vcombine_s32(acc, acc), filter, input));
+ // Store the accumulators back to acc_buffer
+ vst1_s32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 2;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<false, 1, 4> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ uint8x8_t filter_u8 = vdup_n_u8(0);
+ filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
+ filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
+ filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2);
+ filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3);
+ const int16x4_t filter_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
+ const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
+
+ int outp = 0;
+ // Handle 8 output pixels at a time.
+ for (; outp <= num_output_pixels - 8; outp += 8) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[8];
+ for (int i = 0; i < 8; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vld1_u8(input_ptr);
+ input_ptr += 8;
+ const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+ const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+
+ // Multiply-accumulate
+ acc[0] = vmlal_lane_s16(acc[0], filter, vget_low_s16(input), 0);
+ acc[1] = vmlal_lane_s16(acc[1], filter, vget_low_s16(input), 1);
+ acc[2] = vmlal_lane_s16(acc[2], filter, vget_low_s16(input), 2);
+ acc[3] = vmlal_lane_s16(acc[3], filter, vget_low_s16(input), 3);
+ acc[4] = vmlal_lane_s16(acc[4], filter, vget_high_s16(input), 0);
+ acc[5] = vmlal_lane_s16(acc[5], filter, vget_high_s16(input), 1);
+ acc[6] = vmlal_lane_s16(acc[6], filter, vget_high_s16(input), 2);
+ acc[7] = vmlal_lane_s16(acc[7], filter, vget_high_s16(input), 3);
+
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 8; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 32;
+ }
+ // Handle 4 output pixels at a time.
+ for (; outp <= num_output_pixels - 4; outp += 4) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vdup_n_u8(0);
+ input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
+ input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
+ input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
+ input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
+ input_ptr += 4;
+ const int16x4_t input_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+
+ // Multiply-accumulate
+ acc[0] = vmlal_lane_s16(acc[0], filter, input, 0);
+ acc[1] = vmlal_lane_s16(acc[1], filter, input, 1);
+ acc[2] = vmlal_lane_s16(acc[2], filter, input, 2);
+ acc[3] = vmlal_lane_s16(acc[3], filter, input, 3);
+
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ // Handle one output pixel at a time.
+ for (; outp < num_output_pixels; outp++) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc = vld1q_s32(acc_buffer_ptr);
+
+ // Load the inputs, add input_offset.
+ const uint32 input = *input_ptr++ + input_offset;
+
+ // Multiply-accumulate
+ acc = vmlal_n_s16(acc, filter, input);
+ // Store the accumulators back to acc_buffer
+ vst1q_s32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 4;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<false, 4, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ uint8x8_t filter_u8 = vdup_n_u8(0);
+ filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
+ filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
+ filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2);
+ filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3);
+ const int16x4_t filter_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
+ const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
+
+ int outp = 0;
+ // Handle 4 output pixels at a time.
+ for (; outp <= num_output_pixels - 4; outp += 4) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Load the inputs, add input_offset.
+ int16x8_t input[2];
+ for (int i = 0; i < 2; i++) {
+ const uint8x8_t input_u8 = vld1_u8(input_ptr + 8 * i);
+ const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+ input[i] = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+ }
+ input_ptr += 16;
+ // Multiply-accumulate
+ for (int i = 0; i < 2; i++) {
+ acc[2 * i + 0] =
+ vmlal_s16(acc[2 * i + 0], filter, vget_low_s16(input[i]));
+ acc[2 * i + 1] =
+ vmlal_s16(acc[2 * i + 1], filter, vget_high_s16(input[i]));
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ // Handle one output pixel at a time.
+ for (; outp < num_output_pixels; outp++) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc;
+ acc = vld1q_s32(acc_buffer_ptr);
+
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vdup_n_u8(0);
+ input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
+ input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
+ input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
+ input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
+ input_ptr += 4;
+ const int16x4_t input_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+ // Multiply-accumulate
+ acc = vmlal_s16(acc, filter, input);
+ // Store the accumulators back to acc_buffer
+ vst1q_s32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 4;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<false, 4, 4> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ int16x8_t filter[2];
+ for (int i = 0; i < 2; i++) {
+ const uint8x8_t filter_u8 = vld1_u8(filter_ptr + 8 * i);
+ const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
+ filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
+ }
+
+ int outp = 0;
+ // Handle 2 output pixels at a time.
+ for (; outp <= num_output_pixels - 2; outp += 2) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[8];
+ for (int i = 0; i < 8; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vld1_u8(input_ptr);
+ input_ptr += 8;
+ const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+ const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+
+ // Multiply-accumulate
+ acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]),
+ vget_low_s16(input), 0);
+ acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]),
+ vget_low_s16(input), 1);
+ acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]),
+ vget_low_s16(input), 2);
+ acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]),
+ vget_low_s16(input), 3);
+ acc[4] = vmlal_lane_s16(acc[4], vget_low_s16(filter[0]),
+ vget_high_s16(input), 0);
+ acc[5] = vmlal_lane_s16(acc[5], vget_high_s16(filter[0]),
+ vget_high_s16(input), 1);
+ acc[6] = vmlal_lane_s16(acc[6], vget_low_s16(filter[1]),
+ vget_high_s16(input), 2);
+ acc[7] = vmlal_lane_s16(acc[7], vget_high_s16(filter[1]),
+ vget_high_s16(input), 3);
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 8; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 32;
+ }
+ // Handle one output pixel at a time.
+ for (; outp < num_output_pixels; outp++) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vdup_n_u8(0);
+ input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
+ input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
+ input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
+ input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
+ input_ptr += 4;
+ const int16x4_t input_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+
+ // Multiply-accumulate
+ acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0);
+ acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 1);
+ acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 2);
+ acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 3);
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<true, 0, 3> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // We will have to duplicate bytes in a NEON register, 3-fold.
+ // We will do that by register-level table-look-up using VTBL instructions.
+ // Here we prepare the registers containing the table-lookup indices.
+ static const uint8 dup3_indices_array[3][8] = {{0, 0, 0, 1, 1, 1, 2, 2},
+ {2, 3, 3, 3, 4, 4, 4, 5},
+ {5, 5, 6, 6, 6, 7, 7, 7}};
+ uint8x8_t dup3_indices[3];
+ for (int i = 0; i < 3; i++) {
+ dup3_indices[i] = vld1_u8(dup3_indices_array[i]);
+ }
+
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ const uint8* local_filter_ptr = filter_ptr;
+ const uint8* local_input_ptr = input_ptr;
+ int ic = 0;
+ // Handle 8 input channels at a time.
+ for (; ic <= input_depth - 8; ic += 8) {
+ // Load the filters, add filter_offset.
+ int16x8_t filter[3];
+ uint8x8x3_t filter_u8;
+ filter_u8.val[0] = vld1_u8(local_filter_ptr);
+ filter_u8.val[1] = vld1_u8(local_filter_ptr + 8);
+ filter_u8.val[2] = vld1_u8(local_filter_ptr + 16);
+ local_filter_ptr += 24;
+ for (int i = 0; i < 3; i++) {
+ const int16x8_t filter_s16 =
+ vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i]));
+ filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
+ }
+ // Load the inputs, duplicate 3-fold, add input_offset.
+ const uint8x8_t input_u8 = vld1_u8(local_input_ptr);
+ local_input_ptr += 8;
+
+ uint8x8_t input_u8_dup3[3];
+ for (int i = 0; i < 3; i++) {
+ input_u8_dup3[i] = vtbl1_u8(input_u8, dup3_indices[i]);
+ }
+ int16x8_t input_dup3[3];
+ for (int i = 0; i < 3; i++) {
+ const int16x8_t input_s16_dup3 =
+ vreinterpretq_s16_u16(vmovl_u8(input_u8_dup3[i]));
+ input_dup3[i] = vaddq_s16(input_s16_dup3, vdupq_n_s16(input_offset));
+ }
+ // Load the accumulators from acc_buffer
+ int32x4x3_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8);
+ acc[i].val[2] = vld1q_s32(acc_buffer_ptr + 4 * i + 16);
+ }
+ // Multiply-accumulate
+ for (int j = 0; j < 3; j++) {
+ acc[0].val[j] = vmlal_s16(acc[0].val[j], vget_low_s16(input_dup3[j]),
+ vget_low_s16(filter[j]));
+ acc[1].val[j] = vmlal_s16(acc[1].val[j], vget_high_s16(input_dup3[j]),
+ vget_high_s16(filter[j]));
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 2; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]);
+ vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]);
+ vst1q_s32(acc_buffer_ptr + 4 * i + 16, acc[i].val[2]);
+ }
+ acc_buffer_ptr += 24;
+ }
+ // Handle one input channel at a time.
+ for (; ic < input_depth; ic++) {
+ const int16 input_val = *local_input_ptr++ + input_offset;
+ for (int i = 0; i < 3; i++) {
+ const int16 filter_val = local_filter_ptr[i] + filter_offset;
+ *acc_buffer_ptr++ += static_cast<int32>(filter_val) * input_val;
+ }
+ local_filter_ptr += 3;
+ }
+ input_ptr += input_ptr_increment;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<true, 0, 2> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ const uint8* local_filter_ptr = filter_ptr;
+ const uint8* local_input_ptr = input_ptr;
+ int ic = 0;
+ // Handle 8 input channels at a time.
+ for (; ic <= input_depth - 8; ic += 8) {
+ // Load the filters, add filter_offset.
+ int16x8_t filter[2];
+ uint8x8x2_t filter_u8;
+ filter_u8.val[0] = vld1_u8(local_filter_ptr);
+ filter_u8.val[1] = vld1_u8(local_filter_ptr + 8);
+ local_filter_ptr += 16;
+ for (int i = 0; i < 2; i++) {
+ const int16x8_t filter_s16 =
+ vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i]));
+ filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
+ }
+ // Load the inputs, add input_offset, duplicate 2-fold.
+ const uint8x8_t input_u8 = vld1_u8(local_input_ptr);
+ local_input_ptr += 8;
+ const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+ const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+ const int16x8x2_t input_dup2 = vzipq_s16(input, input);
+ // Load the accumulators from acc_buffer.
+ int32x4x2_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8);
+ }
+ // Multiply-accumulate.
+ for (int j = 0; j < 2; j++) {
+ acc[0].val[j] = vmlal_s16(acc[0].val[j], vget_low_s16(filter[j]),
+ vget_low_s16(input_dup2.val[j]));
+ acc[1].val[j] = vmlal_s16(acc[1].val[j], vget_high_s16(filter[j]),
+ vget_high_s16(input_dup2.val[j]));
+ }
+ // Store the accumulators back to acc_buffer.
+ for (int i = 0; i < 2; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]);
+ vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ // Handle one input channel at a time.
+ for (; ic < input_depth; ic++) {
+ // Load the inputs.
+ const int16 input_val = *local_input_ptr++ + input_offset;
+ for (int i = 0; i < 2; i++) {
+ const int16 filter_val = local_filter_ptr[i] + filter_offset;
+ *acc_buffer_ptr++ += static_cast<int32>(filter_val) * input_val;
+ }
+ local_filter_ptr += 2;
+ }
+ input_ptr += input_ptr_increment;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<true, 0, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ const uint8* local_filter_ptr = filter_ptr;
+ const uint8* local_input_ptr = input_ptr;
+ int ic = 0;
+ // Handle 16 input channels at a time.
+ for (; ic <= input_depth - 16; ic += 16) {
+ // Load the filters, add filter_offset.
+ uint8x8_t filter_u8_0 = vld1_u8(local_filter_ptr + 8 * 0);
+ uint8x8_t filter_u8_1 = vld1_u8(local_filter_ptr + 8 * 1);
+ local_filter_ptr += 16;
+ int16x8_t filter_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0));
+ int16x8_t filter_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1));
+ filter_0 = vaddq_s16(filter_0, vdupq_n_s16(filter_offset));
+ filter_1 = vaddq_s16(filter_1, vdupq_n_s16(filter_offset));
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8_0 = vld1_u8(local_input_ptr + 8 * 0);
+ uint8x8_t input_u8_1 = vld1_u8(local_input_ptr + 8 * 1);
+ local_input_ptr += 16;
+ int16x8_t input_0 = vreinterpretq_s16_u16(vmovl_u8(input_u8_0));
+ int16x8_t input_1 = vreinterpretq_s16_u16(vmovl_u8(input_u8_1));
+ input_0 = vaddq_s16(input_0, vdupq_n_s16(input_offset));
+ input_1 = vaddq_s16(input_1, vdupq_n_s16(input_offset));
+ // Load the accumulators from acc_buffer
+ int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0);
+ int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1);
+ int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2);
+ int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3);
+ acc_0 = vmlal_s16(acc_0, vget_low_s16(input_0), vget_low_s16(filter_0));
+ acc_1 =
+ vmlal_s16(acc_1, vget_high_s16(input_0), vget_high_s16(filter_0));
+ acc_2 = vmlal_s16(acc_2, vget_low_s16(input_1), vget_low_s16(filter_1));
+ acc_3 =
+ vmlal_s16(acc_3, vget_high_s16(input_1), vget_high_s16(filter_1));
+ // Store the accumulators back to acc_buffer
+ vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0);
+ vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1);
+ vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2);
+ vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3);
+ acc_buffer_ptr += 16;
+ }
+ // Handle 8 input channels at a time.
+ for (; ic <= input_depth - 8; ic += 8) {
+ // Load the filters, add filter_offset.
+ const uint8x8_t filter_u8 = vld1_u8(local_filter_ptr);
+ local_filter_ptr += 8;
+ const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
+ const int16x8_t filter =
+ vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
+ // Load the inputs, add input_offset.
+ const uint8x8_t input_u8 = vld1_u8(local_input_ptr);
+ local_input_ptr += 8;
+ const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+ const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ acc[0] = vmlal_s16(acc[0], vget_low_s16(input), vget_low_s16(filter));
+ acc[1] = vmlal_s16(acc[1], vget_high_s16(input), vget_high_s16(filter));
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 2; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 8;
+ }
+ // Handle one input channel at a time.
+ for (; ic < input_depth; ic++) {
+ const int16 input_val = *local_input_ptr++ + input_offset;
+ const int16 filter_val = *local_filter_ptr++ + filter_offset;
+ *acc_buffer_ptr++ += static_cast<int32>(filter_val) * input_val;
+ }
+ input_ptr += input_ptr_increment;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<true, 16, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ uint8x8_t filter_u8[2];
+ for (int i = 0; i < 2; i++) {
+ filter_u8[i] = vld1_u8(filter_ptr + 8 * i);
+ }
+ int16x8_t filter[2];
+ for (int i = 0; i < 2; i++) {
+ filter[i] = vreinterpretq_s16_u16(vmovl_u8(filter_u8[i]));
+ }
+ for (int i = 0; i < 2; i++) {
+ filter[i] = vaddq_s16(filter[i], vdupq_n_s16(filter_offset));
+ }
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8[2];
+ for (int i = 0; i < 2; i++) {
+ input_u8[i] = vld1_u8(input_ptr + 8 * i);
+ }
+ input_ptr += input_ptr_increment;
+ int16x8_t input[2];
+ for (int i = 0; i < 2; i++) {
+ input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i]));
+ }
+ for (int i = 0; i < 2; i++) {
+ input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset));
+ }
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ for (int i = 0; i < 2; i++) {
+ acc[2 * i + 0] = vmlal_s16(acc[2 * i + 0], vget_low_s16(input[i]),
+ vget_low_s16(filter[i]));
+ acc[2 * i + 1] = vmlal_s16(acc[2 * i + 1], vget_high_s16(input[i]),
+ vget_high_s16(filter[i]));
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<true, 8, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ const uint8x8_t filter_u8 = vld1_u8(filter_ptr);
+ const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
+ const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ // Load the inputs, add input_offset.
+ const uint8x8_t input_u8 = vld1_u8(input_ptr);
+ const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+ const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ acc[0] = vmlal_s16(acc[0], vget_low_s16(input), vget_low_s16(filter));
+ acc[1] = vmlal_s16(acc[1], vget_high_s16(input), vget_high_s16(filter));
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 2; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 8;
+ input_ptr += input_ptr_increment;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<true, 1, 16> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ uint8x8_t filter_u8[2];
+ for (int i = 0; i < 2; i++) {
+ filter_u8[i] = vld1_u8(filter_ptr + 8 * i);
+ }
+ int16x8_t filter[2];
+ for (int i = 0; i < 2; i++) {
+ filter[i] = vreinterpretq_s16_u16(vmovl_u8(filter_u8[i]));
+ }
+ for (int i = 0; i < 2; i++) {
+ filter[i] = vaddq_s16(filter[i], vdupq_n_s16(filter_offset));
+ }
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ uint8 input_u8 = *input_ptr;
+ input_ptr += input_ptr_increment;
+ int16 input = static_cast<int16>(input_u8 + input_offset);
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ for (int i = 0; i < 2; i++) {
+ acc[2 * i + 0] =
+ vmlal_n_s16(acc[2 * i + 0], vget_low_s16(filter[i]), input);
+ acc[2 * i + 1] =
+ vmlal_n_s16(acc[2 * i + 1], vget_high_s16(filter[i]), input);
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<true, 1, 32> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ uint8x8_t filter_u8_0 = vld1_u8(filter_ptr + 8 * 0);
+ uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 8 * 1);
+ uint8x8_t filter_u8_2 = vld1_u8(filter_ptr + 8 * 2);
+ uint8x8_t filter_u8_3 = vld1_u8(filter_ptr + 8 * 3);
+ int16x8_t filter_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0));
+ int16x8_t filter_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1));
+ int16x8_t filter_2 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_2));
+ int16x8_t filter_3 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_3));
+ filter_0 = vaddq_s16(filter_0, vdupq_n_s16(filter_offset));
+ filter_1 = vaddq_s16(filter_1, vdupq_n_s16(filter_offset));
+ filter_2 = vaddq_s16(filter_2, vdupq_n_s16(filter_offset));
+ filter_3 = vaddq_s16(filter_3, vdupq_n_s16(filter_offset));
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ uint8 input_u8 = *input_ptr;
+ input_ptr += input_ptr_increment;
+ int16 input = static_cast<int16>(input_u8 + input_offset);
+ // Load the accumulators from acc_buffer
+ int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0);
+ int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1);
+ int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2);
+ int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3);
+ int32x4_t acc_4 = vld1q_s32(acc_buffer_ptr + 4 * 4);
+ int32x4_t acc_5 = vld1q_s32(acc_buffer_ptr + 4 * 5);
+ int32x4_t acc_6 = vld1q_s32(acc_buffer_ptr + 4 * 6);
+ int32x4_t acc_7 = vld1q_s32(acc_buffer_ptr + 4 * 7);
+ // Multiply-accumulate
+ acc_0 = vmlal_n_s16(acc_0, vget_low_s16(filter_0), input);
+ acc_1 = vmlal_n_s16(acc_1, vget_high_s16(filter_0), input);
+ acc_2 = vmlal_n_s16(acc_2, vget_low_s16(filter_1), input);
+ acc_3 = vmlal_n_s16(acc_3, vget_high_s16(filter_1), input);
+ acc_4 = vmlal_n_s16(acc_4, vget_low_s16(filter_2), input);
+ acc_5 = vmlal_n_s16(acc_5, vget_high_s16(filter_2), input);
+ acc_6 = vmlal_n_s16(acc_6, vget_low_s16(filter_3), input);
+ acc_7 = vmlal_n_s16(acc_7, vget_high_s16(filter_3), input);
+ // Store the accumulators back to acc_buffer
+ vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0);
+ vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1);
+ vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2);
+ vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3);
+ vst1q_s32(acc_buffer_ptr + 4 * 4, acc_4);
+ vst1q_s32(acc_buffer_ptr + 4 * 5, acc_5);
+ vst1q_s32(acc_buffer_ptr + 4 * 6, acc_6);
+ vst1q_s32(acc_buffer_ptr + 4 * 7, acc_7);
+ acc_buffer_ptr += 32;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<true, 1, 8> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ const uint8x8_t filter_u8 = vld1_u8(filter_ptr);
+ const int16x8_t filter = vaddq_s16(
+ vreinterpretq_s16_u16(vmovl_u8(filter_u8)), vdupq_n_s16(filter_offset));
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ uint8 input_u8 = *input_ptr;
+ input_ptr += input_ptr_increment;
+ int16 input = static_cast<int16>(input_u8 + input_offset);
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ acc[0] = vmlal_n_s16(acc[0], vget_low_s16(filter), input);
+ acc[1] = vmlal_n_s16(acc[1], vget_high_s16(filter), input);
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 2; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 8;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<true, 2, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ uint8x8_t filter_u8 = vdup_n_u8(0);
+ filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
+ filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
+ filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2);
+ filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3);
+ const int16x4_t filter_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
+ const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
+
+ int outp = 0;
+
+ // Handle 2 output pixels at a time.
+ for (; outp <= num_output_pixels - 2; outp += 2) {
+ // Load the accumulators from acc_buffer.
+ int32x4_t acc = vld1q_s32(acc_buffer_ptr);
+ // Load the inputs, add input_offset.
+ uint16x4_t input_u16 = vdup_n_u16(0);
+ input_u16 = vset_lane_u16((reinterpret_cast<const uint16*>(input_ptr))[0],
+ input_u16, 0);
+ input_ptr += input_ptr_increment;
+ input_u16 = vset_lane_u16((reinterpret_cast<const uint16*>(input_ptr))[0],
+ input_u16, 1);
+ input_ptr += input_ptr_increment;
+ const int16x4_t input_s16 = vreinterpret_s16_u16(
+ vget_low_u16(vmovl_u8(vreinterpret_u8_u16(input_u16))));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+
+ // Multiply-accumulate.
+ acc = vmlal_s16(acc, filter, input);
+ // Store the accumulators back to acc_buffer.
+ vst1q_s32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 4;
+ }
+
+ // Handle 1 output pixel at a time.
+ for (; outp < num_output_pixels; outp++) {
+ // Load the accumulators from acc_buffer.
+ int32x2_t acc = vld1_s32(acc_buffer_ptr);
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vdup_n_u8(0);
+ input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
+ input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
+ input_ptr += input_ptr_increment;
+ const int16x4_t input_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+
+ // Multiply-accumulate.
+ acc = vget_low_s32(vmlal_s16(vcombine_s32(acc, acc), filter, input));
+ // Store the accumulators back to acc_buffer.
+ vst1_s32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 2;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<true, 4, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ if (num_output_pixels <= 0) {
+ return;
+ }
+
+ // Load the filters, add filter_offset.
+ uint8x8_t filter_u8 = vdup_n_u8(0);
+ filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
+ filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
+ filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2);
+ filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3);
+ const int16x4_t filter_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
+ const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
+
+ int outp = 0;
+
+ // Handle one output pixel at a time until second to the last pixel. Second
+ // to the last because we read eight input pixels while only processing
+ // four.
+ for (; outp < num_output_pixels - 1; outp++) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc;
+ acc = vld1q_s32(acc_buffer_ptr);
+
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vld1_u8(input_ptr);
+ input_ptr += input_ptr_increment;
+ const int16x4_t input_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+ // Multiply-accumulate
+ acc = vmlal_s16(acc, filter, input);
+ // Store the accumulators back to acc_buffer
+ vst1q_s32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 4;
+ }
+
+ // Handle the last output pixel.
+ // Load the accumulators from acc_buffer
+ int32x4_t acc;
+ acc = vld1q_s32(acc_buffer_ptr);
+
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vdup_n_u8(0);
+ input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
+ input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
+ input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
+ input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
+ const int16x4_t input_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+ // Multiply-accumulate
+ acc = vmlal_s16(acc, filter, input);
+ // Store the accumulators back to acc_buffer
+ vst1q_s32(acc_buffer_ptr, acc);
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<false, 12, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ uint8x8_t filter_u8_0 = vld1_u8(filter_ptr);
+ uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 4);
+ int16x8_t filter_s16_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0));
+ int16x8_t filter_s16_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1));
+ filter_s16_0 = vaddq_s16(filter_s16_0, vdupq_n_s16(filter_offset));
+ filter_s16_1 = vaddq_s16(filter_s16_1, vdupq_n_s16(filter_offset));
+ int16x4_t filter_0 = vget_low_s16(filter_s16_0);
+ int16x4_t filter_1 = vget_high_s16(filter_s16_0);
+ int16x4_t filter_2 = vget_high_s16(filter_s16_1);
+
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8_0 = vld1_u8(input_ptr);
+ uint8x8_t input_u8_1 = vld1_u8(input_ptr + 4);
+ input_ptr += input_ptr_increment;
+ int16x8_t input_0 = vreinterpretq_s16_u16(vmovl_u8(input_u8_0));
+ int16x8_t input_1 = vreinterpretq_s16_u16(vmovl_u8(input_u8_1));
+ input_0 = vaddq_s16(input_0, vdupq_n_s16(input_offset));
+ input_1 = vaddq_s16(input_1, vdupq_n_s16(input_offset));
+
+ // Load the accumulators from acc_buffer
+ int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0);
+ int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1);
+ int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2);
+
+ // Multiply-accumulate
+ acc_0 = vmlal_s16(acc_0, vget_low_s16(input_0), filter_0);
+ acc_1 = vmlal_s16(acc_1, vget_high_s16(input_0), filter_1);
+ acc_2 = vmlal_s16(acc_2, vget_high_s16(input_1), filter_2);
+
+ // Store the accumulators back to acc_buffer
+ vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0);
+ vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1);
+ vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2);
+
+ acc_buffer_ptr += 12;
+ }
+ }
+};
+#endif
+
+// Accumulates the effect of one row of the filter, on a segment of one row
+// of the output, accessing the corresponding one row of the input.
+template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
+void QuantizedDepthwiseConvAccumRow(
+ int stride, int input_depth, int input_width, const uint8* input_data,
+ int16 input_offset, int pad_width, int depth_multiplier, int filter_width,
+ const uint8* filter_data, int16 filter_offset, int out_x_buffer_start,
+ int out_x_buffer_end, int output_depth, int32* acc_buffer) {
+#ifdef GEMMLOWP_PROFILING
+ gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__);
+#endif
+ // Sanity check parameters. This is important in particular to ensure
+ // that we keep the number of template instantiations minimal, so we don't
+ // increase binary size unnecessarily.
+ static_assert(kFixedDepthMultiplier || !kFixedInputDepth, "");
+ static_assert(kFixedInputDepth || kAllowStrided, "");
+ TFLITE_DCHECK(stride == 1 || kAllowStrided);
+ if (kFixedInputDepth) {
+ TFLITE_DCHECK_EQ(input_depth, kFixedInputDepth);
+ }
+ if (kFixedDepthMultiplier) {
+ TFLITE_DCHECK_EQ(depth_multiplier, kFixedDepthMultiplier);
+ }
+ TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
+ const int input_ptr_increment = stride * input_depth;
+ const uint8* filter_base_ptr = filter_data;
+ for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+ // For the current (filter_x, filter_y) point in the filter,
+ // compute the boundaries of the corresponding output row segment.
+ int out_x_loop_start_unclampled = 0;
+ int out_x_loop_end_unclampled = 0;
+ if (kAllowStrided) {
+ if (stride == 2) {
+ out_x_loop_start_unclampled = (pad_width - filter_x + 1) / 2;
+ out_x_loop_end_unclampled =
+ (pad_width + input_width - filter_x + 1) / 2;
+ } else if (stride == 4) {
+ out_x_loop_start_unclampled = (pad_width - filter_x + 3) / 4;
+ out_x_loop_end_unclampled =
+ (pad_width + input_width - filter_x + 3) / 4;
+ } else {
+ out_x_loop_start_unclampled =
+ (pad_width - filter_x + stride - 1) / stride;
+ out_x_loop_end_unclampled =
+ (pad_width + input_width - filter_x + stride - 1) / stride;
+ }
+ } else {
+ out_x_loop_start_unclampled = pad_width - filter_x;
+ out_x_loop_end_unclampled = pad_width + input_width - filter_x;
+ }
+ // The kernel will have to iterate on the segment of the
+ // output row that starts at out_x_loop_start and out_x_loop_end.
+ const int out_x_loop_start =
+ std::max(out_x_buffer_start, out_x_loop_start_unclampled);
+ const int out_x_loop_end =
+ std::min(out_x_buffer_end, out_x_loop_end_unclampled);
+
+ int32* acc_buffer_ptr =
+ acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
+ const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
+ const uint8* input_ptr = input_data + in_x_origin * input_depth;
+ const int num_output_pixels = out_x_loop_end - out_x_loop_start;
+ QuantizedDepthwiseConvKernel<
+ kAllowStrided, kFixedInputDepth,
+ kFixedDepthMultiplier>::Run(num_output_pixels, input_depth,
+ depth_multiplier, input_ptr, input_offset,
+ input_ptr_increment, filter_base_ptr,
+ filter_offset, acc_buffer_ptr);
+ filter_base_ptr += output_depth;
+ }
+}
+
+// generic fallback of DepthwiseConvAccumRow, portable, non-templatized.
+inline void QuantizedDepthwiseConvAccumRowGeneric(
+ int stride, int input_depth, int input_width, const uint8* input_data,
+ int16 input_offset, int pad_width, int depth_multiplier, int filter_width,
+ const uint8* filter_data, int16 filter_offset, int out_x_buffer_start,
+ int out_x_buffer_end, int output_depth, int32* acc_buffer) {
+ gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)");
+#ifdef TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+ LOG(FATAL)
+ << "\n\n"
+ << "*****************************************************************\n"
+ << "* This tfmini inference code was about to use the slow generic\n"
+ << "* fallback implementation for a DepthwiseConv op, and we want you\n"
+ << "* to be aware of that so that you will know why you get terrible\n"
+ << "* performance.\n"
+ << "*\n"
+ << "* If you would like to carry on with the slow code, compile\n"
+ << "* with this preprocessor token defined:\n"
+ << "* TFLITE_ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK.\n"
+ << "*\n"
+ << "* The right thing to do, if you care about performance, is to add\n"
+ << "* a new DepthwiseConv kernel to tfmini to cover your case.\n"
+ << "* The relevant parameters defining your case are:\n"
+ << "* stride = " << stride << "\n"
+ << "* input_depth = " << input_depth << "\n"
+ << "* depth_multiplier = " << depth_multiplier << "\n"
+ << "*\n"
+ << "* Please do not hesitate to contact benoitjacob@ with this\n"
+ << "* information.\n"
+ << "*****************************************************************\n";
+#endif // ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+#endif // TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+ const uint8* filter_base_ptr = filter_data;
+ for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+ const int out_x_loop_start = std::max(
+ out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride);
+ const int out_x_loop_end =
+ std::min(out_x_buffer_end,
+ (pad_width + input_width - filter_x + stride - 1) / stride);
+
+ int32* acc_buffer_ptr =
+ acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
+ const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
+ const uint8* input_ptr = input_data + in_x_origin * input_depth;
+ const int input_ptr_increment = (stride - 1) * input_depth;
+ for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) {
+ const uint8* filter_ptr = filter_base_ptr;
+ for (int ic = 0; ic < input_depth; ++ic) {
+ const int16 input_val = *input_ptr++ + input_offset;
+ for (int m = 0; m < depth_multiplier; m++) {
+ const int16 filter_val = *filter_ptr++ + filter_offset;
+ *acc_buffer_ptr++ += static_cast<int32>(filter_val) * input_val;
+ }
+ }
+ input_ptr += input_ptr_increment;
+ }
+ filter_base_ptr += output_depth;
+ }
+}
+
+// Initializes the accumulator buffer with bias values.
+inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth,
+ const int32* bias_data,
+ int32* acc_buffer) {
+ int i = 0;
+#ifdef USE_NEON
+ if (output_depth == 1) {
+ const int32x4_t b = vdupq_n_s32(bias_data[0]);
+ for (; i <= num_output_pixels - 16; i += 16) {
+ vst1q_s32(acc_buffer + i + 0, b);
+ vst1q_s32(acc_buffer + i + 4, b);
+ vst1q_s32(acc_buffer + i + 8, b);
+ vst1q_s32(acc_buffer + i + 12, b);
+ }
+ for (; i <= num_output_pixels - 4; i += 4) {
+ vst1q_s32(acc_buffer + i, b);
+ }
+ } else if (output_depth == 2) {
+ int32x4_t b = vdupq_n_s32(bias_data[0]);
+ b = vsetq_lane_s32(bias_data[1], b, 1);
+ b = vsetq_lane_s32(bias_data[1], b, 3);
+ for (; i <= num_output_pixels - 8; i += 8) {
+ vst1q_s32(acc_buffer + 2 * i + 0, b);
+ vst1q_s32(acc_buffer + 2 * i + 4, b);
+ vst1q_s32(acc_buffer + 2 * i + 8, b);
+ vst1q_s32(acc_buffer + 2 * i + 12, b);
+ }
+ for (; i <= num_output_pixels - 2; i += 2) {
+ vst1q_s32(acc_buffer + 2 * i, b);
+ }
+ } else if (output_depth == 4) {
+ const int32x4_t b = vld1q_s32(bias_data);
+ for (; i <= num_output_pixels - 4; i += 4) {
+ vst1q_s32(acc_buffer + 4 * i + 0, b);
+ vst1q_s32(acc_buffer + 4 * i + 4, b);
+ vst1q_s32(acc_buffer + 4 * i + 8, b);
+ vst1q_s32(acc_buffer + 4 * i + 12, b);
+ }
+ for (; i < num_output_pixels; i++) {
+ vst1q_s32(acc_buffer + 4 * i, b);
+ }
+ } else if (output_depth == 8) {
+ const int32x4_t b0 = vld1q_s32(bias_data);
+ const int32x4_t b1 = vld1q_s32(bias_data + 4);
+ for (; i <= num_output_pixels - 2; i += 2) {
+ vst1q_s32(acc_buffer + 8 * i + 0, b0);
+ vst1q_s32(acc_buffer + 8 * i + 4, b1);
+ vst1q_s32(acc_buffer + 8 * i + 8, b0);
+ vst1q_s32(acc_buffer + 8 * i + 12, b1);
+ }
+ for (; i < num_output_pixels; i++) {
+ vst1q_s32(acc_buffer + 8 * i + 0, b0);
+ vst1q_s32(acc_buffer + 8 * i + 4, b1);
+ }
+ } else if (output_depth == 16) {
+ const int32x4_t b0 = vld1q_s32(bias_data);
+ const int32x4_t b1 = vld1q_s32(bias_data + 4);
+ const int32x4_t b2 = vld1q_s32(bias_data + 8);
+ const int32x4_t b3 = vld1q_s32(bias_data + 12);
+ for (; i < num_output_pixels; i++) {
+ vst1q_s32(acc_buffer + 16 * i + 0, b0);
+ vst1q_s32(acc_buffer + 16 * i + 4, b1);
+ vst1q_s32(acc_buffer + 16 * i + 8, b2);
+ vst1q_s32(acc_buffer + 16 * i + 12, b3);
+ }
+ }
+#endif
+ for (; i < num_output_pixels; i++) {
+ memcpy(acc_buffer + i * output_depth, bias_data,
+ sizeof(acc_buffer[0]) * output_depth);
+ }
+}
+
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("DepthwiseConv/8bit");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int input_depth = ArraySize(input_dims, 0);
+ const int filter_height = ArraySize(filter_dims, 2);
+ const int filter_width = ArraySize(filter_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+
+ static const int kAccBufferMaxSize = 2048;
+ int32 acc_buffer[kAccBufferMaxSize];
+ TFLITE_DCHECK_GE(kAccBufferMaxSize, output_depth);
+ const int kOutputPixelsInAccBuffer = kAccBufferMaxSize / output_depth;
+ const int kAccBufferActualSize = kOutputPixelsInAccBuffer * output_depth;
+ TFLITE_DCHECK_LE(kOutputPixelsInAccBuffer * output_depth,
+ kAccBufferActualSize);
+ TFLITE_DCHECK_LE(kAccBufferActualSize, kAccBufferMaxSize);
+ TFLITE_DCHECK_GE(kOutputPixelsInAccBuffer, 1);
+
+ // row_accum_func will point to the core accumulation function to be used
+ // for this DepthwiseConv op.
+ using row_accum_func_t = decltype(&QuantizedDepthwiseConvAccumRowGeneric);
+ row_accum_func_t row_accum_func = nullptr;
+
+#define TFMINI_USE_DEPTHWISECONV_KERNEL(ALLOW_STRIDED, FIXED_INPUT_DEPTH, \
+ FIXED_DEPTH_MULTIPLIER) \
+ if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) && \
+ (input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) && \
+ depth_multiplier == FIXED_DEPTH_MULTIPLIER) { \
+ row_accum_func = \
+ QuantizedDepthwiseConvAccumRow<ALLOW_STRIDED, FIXED_INPUT_DEPTH, \
+ FIXED_DEPTH_MULTIPLIER>; \
+ }
+
+#ifdef USE_NEON
+ // We go over our list of kernels by decreasing order of preference
+ // for the cases where multiple kernels could apply.
+
+ // Start with the fastest kernels: AllowStrided=false, fixed input depth.
+
+ TFMINI_USE_DEPTHWISECONV_KERNEL(false, 1, 2)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 2)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 2)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(false, 1, 4)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 1)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 4)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(false, 8, 1)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 8)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 1)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(false, 12, 1)
+
+ // Next come the strided kernels: AllowStrided=true, fixed input depth.
+ // They are a bit less efficient, but allow stride!=1.
+
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 2)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 16, 1)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 16)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 32)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 1)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 2, 1)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 4, 1)
+
+ // Finally, the kernels allowing a variable input depth,
+ // these are the least efficient but most general kernels.
+
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 1)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 2)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 3)
+#endif // USE_NEON
+
+ // No matching fast kernel found, use slow fallback.
+ if (!row_accum_func) {
+ row_accum_func = QuantizedDepthwiseConvAccumRowGeneric;
+ }
+
+#undef TFMINI_USE_DEPTHWISECONV_KERNEL
+
+ // Now that we have determined row_accum_func, we can start work.
+ uint8* output_ptr = output_data;
+ for (int b = 0; b < batches; ++b) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_end =
+ std::min(filter_height, input_height - in_y_origin);
+ for (int out_x_buffer_start = 0; out_x_buffer_start < output_width;
+ out_x_buffer_start += kOutputPixelsInAccBuffer) {
+ const int out_x_buffer_end = std::min(
+ output_width, out_x_buffer_start + kOutputPixelsInAccBuffer);
+ // We call a 'pixel' a group of activation that share all but the
+ // 'depth'/'channel' coordinate. num_output_pixels is the number of
+ // output pixels that we will accumulate in this loop iteration.
+ const int num_output_pixels = out_x_buffer_end - out_x_buffer_start;
+ // Initialize our local accumulator with the bias values, so we don't
+ // have to add them later.
+ DepthwiseConvInitAccBuffer(num_output_pixels, output_depth, bias_data,
+ acc_buffer);
+ // Accumulation loop. Most of the time should be spent in here.
+ for (int filter_y = filter_y_start; filter_y < filter_y_end;
+ ++filter_y) {
+ const int in_y = in_y_origin + filter_y;
+ row_accum_func(
+ stride_width, input_depth, input_width,
+ input_data + in_y * input_dims.strides[2] +
+ b * input_dims.strides[3],
+ input_offset, pad_width, depth_multiplier, filter_width,
+ filter_data + filter_y * filter_dims.strides[2], filter_offset,
+ out_x_buffer_start, out_x_buffer_end, output_depth, acc_buffer);
+ }
+ // Finished accumulating int32 values. Now need to convert them to
+ // the final 8bit form and store them.
+ gemmlowp::ScopedProfilingLabel label("downquantize+store");
+ const int num_output_values = output_depth * num_output_pixels;
+ int i = 0;
+#ifdef USE_NEON
+ using gemmlowp::RoundingDivideByPOT;
+ const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
+ const int32x4_t output_activation_min_vec =
+ vdupq_n_s32(output_activation_min);
+ const int32x4_t output_activation_max_vec =
+ vdupq_n_s32(output_activation_max);
+ // Handle 16 values at once.
+ // This allows us to issue 4 mutually independent int32
+ // multiplications (vqrdmulh), which should alleviate most of their
+ // high latency.
+ for (; i <= num_output_values - 16; i += 16) {
+ int32x4_t acc[4];
+ for (int j = 0; j < 4; j++) {
+ acc[j] = vld1q_s32(acc_buffer + i + 4 * j);
+ }
+
+ // Fixed-point multiplication.
+ for (int j = 0; j < 4; j++) {
+ acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier);
+ }
+ for (int j = 0; j < 4; j++) {
+ acc[j] = RoundingDivideByPOT(acc[j], output_shift);
+ }
+ // Add the output offset.
+ for (int j = 0; j < 4; j++) {
+ acc[j] = vaddq_s32(acc[j], output_offset_vec);
+ }
+ // Apply the activation function.
+ for (int j = 0; j < 4; j++) {
+ acc[j] = vmaxq_s32(acc[j], output_activation_min_vec);
+ }
+ for (int j = 0; j < 4; j++) {
+ acc[j] = vminq_s32(acc[j], output_activation_max_vec);
+ }
+ // Saturating cast to uint8 and store to destination.
+ int16x4_t acc_s16[4];
+ for (int j = 0; j < 4; j++) {
+ acc_s16[j] = vqmovn_s32(acc[j]);
+ }
+ const int16x8_t res_s16_0 = vcombine_s16(acc_s16[0], acc_s16[1]);
+ const int16x8_t res_s16_1 = vcombine_s16(acc_s16[2], acc_s16[3]);
+ const uint8x8_t res_u8_0 = vqmovun_s16(res_s16_0);
+ const uint8x8_t res_u8_1 = vqmovun_s16(res_s16_1);
+ vst1q_u8(output_ptr, vcombine_u8(res_u8_0, res_u8_1));
+ output_ptr += 16;
+ }
+ // Handle 8 values at once.
+ // Not as good as 16 (now we're only issuing 2 mutually independent
+ // vqrdmulh instructions, so we're probably paying for their high
+ // latency).
+ for (; i <= num_output_values - 8; i += 8) {
+ int32x4_t acc0 = vld1q_s32(acc_buffer + i);
+ int32x4_t acc1 = vld1q_s32(acc_buffer + i + 4);
+ // Fixed-point multiplication.
+ acc0 = vqrdmulhq_n_s32(acc0, output_multiplier);
+ acc1 = vqrdmulhq_n_s32(acc1, output_multiplier);
+ // Rounding right shift.
+ acc0 = RoundingDivideByPOT(acc0, output_shift);
+ acc1 = RoundingDivideByPOT(acc1, output_shift);
+ // Add the output offset.
+ acc0 = vaddq_s32(acc0, output_offset_vec);
+ acc1 = vaddq_s32(acc1, output_offset_vec);
+ // Apply the activation function.
+ acc0 = vmaxq_s32(acc0, output_activation_min_vec);
+ acc1 = vmaxq_s32(acc1, output_activation_min_vec);
+ acc0 = vminq_s32(acc0, output_activation_max_vec);
+ acc1 = vminq_s32(acc1, output_activation_max_vec);
+ // Saturating cast to uint8 and store to destination.
+ const int16x4_t acc0_s16 = vqmovn_s32(acc0);
+ const int16x4_t acc1_s16 = vqmovn_s32(acc1);
+ const int16x8_t res_s16 = vcombine_s16(acc0_s16, acc1_s16);
+ const uint8x8_t res_u8 = vqmovun_s16(res_s16);
+ vst1_u8(output_ptr, res_u8);
+ output_ptr += 8;
+ }
+ // Handle 4 values at once. Now we're paying the full price of the
+ // high latency of vqrdmulh. Also, storing only 4 bytes at the end
+ // (without any alignment) can only be done 1 byte at a time.
+ // Yet, that is still worth doing to minimize the amount of leftover
+ // that will have to go through the very slow scalar code.
+ for (; i <= num_output_values - 4; i += 4) {
+ int32x4_t acc = vld1q_s32(acc_buffer + i);
+ // Fixed-point multiplication.
+ acc = vqrdmulhq_n_s32(acc, output_multiplier);
+ // Rounding right shift.
+ acc = RoundingDivideByPOT(acc, output_shift);
+ // Add the output offset.
+ acc = vaddq_s32(acc, output_offset_vec);
+ // Apply the activation function.
+ acc = vmaxq_s32(acc, output_activation_min_vec);
+ acc = vminq_s32(acc, output_activation_max_vec);
+ // Saturating cast to uint8 and store to destination.
+ const int16x4_t acc_s16 = vqmovn_s32(acc);
+ const int16x8_t res_s16 = vcombine_s16(acc_s16, acc_s16);
+ const uint8x8_t res_u8 = vqmovun_s16(res_s16);
+ vst1_lane_u8(output_ptr + 0, res_u8, 0);
+ vst1_lane_u8(output_ptr + 1, res_u8, 1);
+ vst1_lane_u8(output_ptr + 2, res_u8, 2);
+ vst1_lane_u8(output_ptr + 3, res_u8, 3);
+ output_ptr += 4;
+ }
+#endif // USE_NEON
+
+ // Handle leftover values, one by one. This is very slow.
+ for (; i < num_output_values; i++) {
+ int32 acc = acc_buffer[i];
+ acc = MultiplyByQuantizedMultiplierSmallerThanOne(
+ acc, output_multiplier, output_shift);
+ acc += output_offset;
+ acc = std::max(acc, output_activation_min);
+ acc = std::min(acc, output_activation_max);
+ *output_ptr++ = static_cast<uint8>(acc);
+ }
+ }
+ }
+ }
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width,
+ stride_height, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
+ filter_dims, filter_offset, bias_data, bias_dims, stride,
+ stride, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+} // namespace optimized_ops
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h
new file mode 100644
index 0000000000..8004c24a99
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h
@@ -0,0 +1,231 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Copied from tensorflow/core/kernels/eigen_spatial_convolutions.h.
+// TODO(petewarden) - move this to a common location in Eigen itself.
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_
+
+#define EIGEN_USE_CUSTOM_THREAD_POOL
+#define EIGEN_USE_THREADS
+
+// NOTE: Eigen is slightly different internally and externally. We need to
+// hack the unsupported/Eigen/CXX11/Tensor header instantiation macros at
+// specific places, so we need two copies of the hacked file, one for
+// internal and one for external.
+// If you have trouble simply undef out the reducer macro e.g.
+// TFLITE_REDUCE_INSTANTIATIONS_GOOGLE, but be aware this will make
+// the binary much bigger!
+#define TFLITE_REDUCE_INSTANTIATIONS_OPEN_SOURCE
+#define Eigen EigenForTFLite
+#if defined(TFLITE_REDUCE_INSTANTIATIONS_GOOGLE)
+#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h"
+#elif defined(TFLITE_REDUCE_INSTANTIATIONS_OPEN_SOURCE)
+#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h"
+#else
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#endif
+
+
+namespace Eigen {
+
+/** SpatialConvolution
+ * \ingroup CXX11_NeuralNetworks_Module
+ *
+ * \brief Applies a 2D convolution over a multichannel input image.
+ *
+ * The input parameter is expected to be a tensor with a rank of 3 or more
+ * (channels, height, width, and optionally others)
+ * The kernel parameter is expected to be a 4D tensor (filters, channels,
+ * kernel_height, kernel_width)
+ * The input and the kernel must both be in col-major layout. The result will
+ * also be in col-major layout.
+ *
+ * If col_in_stride, row_in_stride > 1, then applies convolution with holes
+ * (aka atrous convolution), sampling every col_in_stride, row_in_stride input
+ * pixels.
+ *
+ * The result can be assigned to a tensor of rank equal to the rank of the
+ * input. The dimensions of the result will be filters, height, width (and
+ * others if applicable).
+ *
+ * It is possible to swap the order of the width and height dimensions provided
+ * that the same order is used in the input, the kernel, and the output.
+ *
+ */
+template <typename Input, typename Kernel>
+EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
+ internal::traits<Input>::Layout == ColMajor,
+ TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const TensorContractionOp<
+ const array<IndexPair<typename internal::traits<Input>::Index>,
+ 1>,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const Kernel>,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const TensorImagePatchOp<Dynamic, Dynamic,
+ const Input> > > >,
+ TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const TensorContractionOp<
+ const array<IndexPair<typename internal::traits<Input>::Index>,
+ 1>,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const Kernel> > > >::type
+ SpatialConvolution(const Input& input, const Kernel& kernel,
+ const DenseIndex row_stride = 1,
+ const DenseIndex col_stride = 1,
+ const PaddingType padding_type = PADDING_SAME,
+ const DenseIndex row_in_stride = 1,
+ const DenseIndex col_in_stride = 1) {
+ typedef typename internal::traits<Input>::Index TensorIndex;
+ TensorRef<Tensor<typename internal::traits<Input>::Scalar,
+ internal::traits<Input>::NumDimensions,
+ internal::traits<Input>::Layout, TensorIndex> >
+ in(input);
+ TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
+ internal::traits<Kernel>::NumDimensions,
+ internal::traits<Kernel>::Layout, TensorIndex> >
+ kern(kernel);
+
+ EIGEN_STATIC_ASSERT(
+ internal::traits<Input>::Layout == internal::traits<Kernel>::Layout,
+ YOU_MADE_A_PROGRAMMING_MISTAKE);
+ const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
+
+ const int NumDims = internal::traits<Input>::NumDimensions;
+
+ // Number of filters to apply. This is the same as the output depth of the
+ // result
+ const TensorIndex kernelFilters =
+ isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
+ // Number of channels. This is the same as the input depth.
+ const TensorIndex kernelChannels =
+ isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
+ const TensorIndex kernelRows =
+ isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
+ const TensorIndex kernelCols =
+ isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];
+
+ const DenseIndex kernelRowsEff =
+ kernelRows + (kernelRows - 1) * (row_in_stride - 1);
+ const DenseIndex kernelColsEff =
+ kernelCols + (kernelCols - 1) * (col_in_stride - 1);
+
+ array<IndexPair<TensorIndex>, 1> contract_dims;
+ contract_dims[0] = IndexPair<TensorIndex>(1, 0);
+
+ const TensorIndex InputRows =
+ isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
+ const TensorIndex InputCols =
+ isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
+
+ TensorIndex out_height;
+ TensorIndex out_width;
+ switch (padding_type) {
+ case PADDING_VALID:
+ out_height = numext::ceil((InputRows - kernelRowsEff + 1.f) /
+ static_cast<float>(row_stride));
+ out_width = numext::ceil((InputCols - kernelColsEff + 1.f) /
+ static_cast<float>(col_stride));
+ break;
+ case PADDING_SAME:
+ out_height = numext::ceil(InputRows / static_cast<float>(row_stride));
+ out_width = numext::ceil(InputCols / static_cast<float>(col_stride));
+ break;
+ default:
+ // Initialize unused variables to avoid a compiler warning
+ out_height = 0;
+ out_width = 0;
+ eigen_assert(false && "unexpected padding");
+ }
+
+ // Molds the output of the patch extraction code into a 2d tensor:
+ // - the first dimension (dims[0]): the patch values to be multiplied with the
+ // kernels
+ // - the second dimension (dims[1]): everything else
+ DSizes<TensorIndex, 2> pre_contract_dims;
+ if (isColMajor) {
+ pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols;
+ pre_contract_dims[1] = out_height * out_width;
+ for (int i = 3; i < NumDims; ++i) {
+ pre_contract_dims[1] *= in.dimension(i);
+ }
+ } else {
+ pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols;
+ pre_contract_dims[0] = out_height * out_width;
+ for (int i = 0; i < NumDims - 3; ++i) {
+ pre_contract_dims[0] *= in.dimension(i);
+ }
+ }
+
+ // Molds the output of the contraction into the shape expected by the used
+ // (assuming this is ColMajor):
+ // - 1st dim: kernel filters
+ // - 2nd dim: output height
+ // - 3rd dim: output width
+ // - 4th dim and beyond: everything else including batch size
+ DSizes<TensorIndex, NumDims> post_contract_dims;
+ if (isColMajor) {
+ post_contract_dims[0] = kernelFilters;
+ post_contract_dims[1] = out_height;
+ post_contract_dims[2] = out_width;
+ for (int i = 3; i < NumDims; ++i) {
+ post_contract_dims[i] = in.dimension(i);
+ }
+ } else {
+ post_contract_dims[NumDims - 1] = kernelFilters;
+ post_contract_dims[NumDims - 2] = out_height;
+ post_contract_dims[NumDims - 3] = out_width;
+ for (int i = 0; i < NumDims - 3; ++i) {
+ post_contract_dims[i] = in.dimension(i);
+ }
+ }
+
+ DSizes<TensorIndex, 2> kernel_dims;
+ if (isColMajor) {
+ kernel_dims[0] = kernelFilters;
+ kernel_dims[1] = kernelChannels * kernelRows * kernelCols;
+ } else {
+ kernel_dims[0] = kernelChannels * kernelRows * kernelCols;
+ kernel_dims[1] = kernelFilters;
+ }
+ // TODO(yangke): choose() is defined in TensorContraction.h -- consider
+ // moving it to somewhere more "common".
+ return
+ input
+ .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride,
+ row_in_stride, col_in_stride, padding_type)
+ .reshape(pre_contract_dims)
+ .contract(kernel.reshape(kernel_dims), contract_dims)
+ .reshape(post_contract_dims);
+}
+
+} // end namespace Eigen
+
+// clang-format on
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h
new file mode 100644
index 0000000000..7f78f69360
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h
@@ -0,0 +1,143 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_
+
+#define EIGEN_USE_CUSTOM_THREAD_POOL
+#define EIGEN_USE_THREADS
+
+// clang-format off
+
+#include <stdint.h>
+
+#include <cstddef>
+#include <cstring>
+#include <cmath>
+#include <random>
+#include <atomic>
+#include <condition_variable> // NOLINT(build/c++11)
+#include <mutex> // NOLINT(build/c++11)
+#include <thread> // NOLINT(build/c++11)
+#include <functional>
+
+#ifdef _WIN32
+#include <winbase.h>
+#elif defined(__APPLE__)
+#include <mach/mach_time.h>
+#else
+#include <time.h>
+#endif
+
+
+// Because some programs may link Eigen in through other frameworks with
+// different flags, we can run into multiple definition issues if we don't have
+// a private namespace for our versions. This is a nasty hack, but a similar
+// approach is used elsewhere to handle the problem, so it should be stable.
+#define Eigen EigenForTFLite
+
+#include "Eigen/src/Core/util/StaticAssert.h"
+#include "unsupported/Eigen/CXX11/Core"
+#include "unsupported/Eigen/SpecialFunctions"
+
+#include "Eigen/src/Core/util/DisableStupidWarnings.h"
+
+#include "Eigen/Core"
+
+// Beware: the order of the include matters to some compilers. For example
+// TensorIndexList.h should be included before TensorDimensions.h in order to
+// use index lists to encode tensor dimensions when compiling with llvm.
+// We're defining this ourselves rather than using the Eigen Tensor header file
+// so that we can alter the macro definition of TENSOR_CONTRACTION_DISPATCH to
+// reduce binary size.
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/ThreadPoolInterface.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceType.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorNonBlockingThreadPool.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorUInt128.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStats.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMappers.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h"
+#undef TENSOR_CONTRACTION_DISPATCH
+#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \
+ if (this->m_lhs_inner_dim_contiguous && \
+ this->m_rhs_inner_dim_contiguous && \
+ !this->m_rhs_inner_dim_reordered) { \
+ METHOD<true, true, false, ALIGNMENT> ARGS; \
+ } else { \
+ eigen_assert(false && "Unsupported contraction formats"); \
+ }
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorImagePatch.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorVolumePatch.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorInflation.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorLayoutSwap.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorCustomOp.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/Tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFixedSize.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReductionCuda.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h"
+
+#include "Eigen/src/Core/util/ReenableStupidWarnings.h"
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_H
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h
new file mode 100644
index 0000000000..1d5c316194
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h
@@ -0,0 +1,167 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This is essentially unsupported/CXX11/Eigen/Tensor.h
+// TODO(petewarden) - move this to a common location in Eigen itself.
+
+// clang-format off
+
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_
+
+
+#include "Eigen/Core"
+
+#if defined(EIGEN_USE_SYCL)
+#undef min
+#undef max
+#undef isnan
+#undef isinf
+#undef isfinite
+#include <CL/sycl.hpp>
+#include <iostream>
+#include <map>
+#include <memory>
+#include <utility>
+#endif
+#include <cmath>
+#include <cstddef>
+#include <cstring>
+
+
+
+
+
+#ifdef _WIN32
+typedef __int16 int16_t;
+typedef unsigned __int16 uint16_t;
+typedef __int32 int32_t;
+typedef unsigned __int32 uint32_t;
+typedef __int64 int64_t;
+typedef unsigned __int64 uint64_t;
+#include <windows.h>
+#else
+#include <stdint.h>
+#include <unistd.h>
+#endif
+
+#if __cplusplus > 199711 || EIGEN_COMP_MSVC >= 1900
+#include <random>
+#endif
+
+#ifdef _WIN32
+#include <windows.h>
+#elif defined(__APPLE__)
+#include <mach/mach_time.h>
+#else
+#include <time.h>
+#endif
+
+// #if defined(EIGEN_USE_LIBXSMM)
+// #include "libxsmm.h"
+// #endif
+
+#ifdef EIGEN_USE_THREADS
+#include "unsupported/Eigen/CXX11/ThreadPool"
+#endif
+
+
+#include "Eigen/src/Core/util/DisableStupidWarnings.h"
+
+#include "unsupported/Eigen/SpecialFunctions"
+#include "unsupported/Eigen/CXX11/src/util/CXX11Meta.h"
+#include "unsupported/Eigen/CXX11/src/util/MaxSizeVector.h"
+
+
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorMacros.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h"
+
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceDefault.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceCuda.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorUInt128.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorBase.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorReductionCuda.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h"
+
+#undef TENSOR_CONTRACTION_DISPATCH
+#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \
+ if (this->m_lhs_inner_dim_contiguous && \
+ this->m_rhs_inner_dim_contiguous && \
+ !this->m_rhs_inner_dim_reordered) { \
+ METHOD<true, true, false, ALIGNMENT> ARGS; \
+ } else { \
+ eigen_assert(false && "Unsupported contraction formats"); \
+ }
+
+
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorImagePatch.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorVolumePatch.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorInflation.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorLayoutSwap.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorCustomOp.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorScan.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorTrace.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorSycl.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/Tensor.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorFixedSize.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorMap.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorRef.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorIO.h"
+
+#include "Eigen/src/Core/util/ReenableStupidWarnings.h"
+
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
new file mode 100644
index 0000000000..b3615f4658
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
@@ -0,0 +1,195 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV
+
+#include <assert.h>
+#include <stdint.h>
+#include <sys/types.h>
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <memory>
+#include <tuple>
+#include <type_traits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace multithreaded_ops {
+
+class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
+ public:
+ explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {}
+ ~EigenThreadPoolWrapper() override {}
+
+ void Schedule(std::function<void()> fn) override {
+ pool_->Schedule(std::move(fn));
+ }
+ int NumThreads() const override { return pool_->NumThreads(); }
+ int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
+
+ private:
+ Eigen::ThreadPool* pool_ = nullptr;
+};
+
+// We have a single global threadpool for all convolution operations. This means
+// that inferences started from different threads may block each other, but
+// since the underlying resource of CPU cores should be consumed by the
+// operations anyway, it shouldn't affect overall performance.
+const Eigen::ThreadPoolDevice& GetThreadPoolDevice() {
+ const int thread_count = 4;
+ static Eigen::ThreadPool* tp = new Eigen::ThreadPool(thread_count);
+ static EigenThreadPoolWrapper* thread_pool_wrapper =
+ new EigenThreadPoolWrapper(tp);
+ static Eigen::ThreadPoolDevice* device =
+ new Eigen::ThreadPoolDevice(thread_pool_wrapper, thread_count);
+ return *device;
+}
+
+// Shorthands for the types we need when interfacing with the EigenTensor
+// library.
+typedef Eigen::TensorMap<
+ Eigen::Tensor<float, 2, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned>
+ EigenMatrix;
+typedef Eigen::TensorMap<
+ Eigen::Tensor<const float, 2, Eigen::RowMajor, Eigen::DenseIndex>,
+ Eigen::Aligned>
+ ConstEigenMatrix;
+
+typedef Eigen::TensorMap<
+ Eigen::Tensor<float, 4, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned>
+ EigenTensor;
+typedef Eigen::TensorMap<
+ Eigen::Tensor<const float, 4, Eigen::RowMajor, Eigen::DenseIndex>,
+ Eigen::Aligned>
+ ConstEigenTensor;
+
+// Utility functions we need for the EigenTensor API.
+template <typename Device, typename T>
+struct MatMulConvFunctor {
+ // Computes on device "d": out = in0 * in1, where * is matrix
+ // multiplication.
+ void operator()(
+ const Device& d, EigenMatrix out, ConstEigenMatrix in0,
+ ConstEigenMatrix in1,
+ const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
+ out.device(d) = in0.contract(in1, dim_pair);
+ }
+};
+
+template <class T>
+class EigenTensorConvFunctor {
+ private:
+ Eigen::PaddingType TfLitePadding2EigenPadding(TfLitePadding padding) {
+ switch (padding) {
+ case kTfLitePaddingValid:
+ return Eigen::PADDING_VALID;
+ case kTfLitePaddingSame:
+ return Eigen::PADDING_SAME;
+ case kTfLitePaddingUnknown:
+ assert(false); // should never get here.
+ return Eigen::PADDING_VALID;
+ }
+ return Eigen::PADDING_SAME; // Prevent compiler warning about missing
+ // return
+ }
+
+ public:
+ void operator()(const T* input_data, T* im2col_buffer, int input_batches,
+ int input_height, int input_width, int input_depth,
+ const T* filter_data, int filter_height, int filter_width,
+ int filter_count, int stride_rows, int stride_cols,
+ int pad_width, int pad_height, TfLitePadding padding,
+ T* output_data, int output_height, int output_width) {
+ const Eigen::ThreadPoolDevice& device = GetThreadPoolDevice();
+
+ const bool is_1x1_kernel = (filter_height == 1 && filter_width == 1 &&
+ stride_rows == 1 && stride_cols == 1);
+ if (is_1x1_kernel) {
+ // For 1x1 kernel, the 2D convolution is reduced to matrix
+ // multiplication.
+ const int conv_width = output_height * output_width;
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
+ dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
+ EigenMatrix output(output_data, conv_width, filter_count);
+ ConstEigenMatrix input(input_data, conv_width, input_depth);
+ ConstEigenMatrix filter(filter_data, input_depth, filter_count);
+ MatMulConvFunctor<Eigen::ThreadPoolDevice, T>()(device, output, input,
+ filter, dim_pair);
+ } else if (filter_height == input_height && filter_width == input_width &&
+ pad_width == 0 && pad_height == 0) {
+ // If the input data and filter have the same height/width,
+ // the 2D convolution is reduced to matrix multiplication.
+ const int k = // Length of reduction dimension.
+ filter_width * filter_height * input_depth;
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
+ dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
+ EigenMatrix output(output_data, 1, filter_count);
+ ConstEigenMatrix input(input_data, 1, k);
+ ConstEigenMatrix filter(filter_data, k, filter_count);
+ MatMulConvFunctor<Eigen::ThreadPoolDevice, T>()(device, output, input,
+ filter, dim_pair);
+ } else {
+ EigenTensor output(output_data, input_batches, output_height,
+ output_width, filter_count);
+ ConstEigenTensor input(input_data, input_batches, input_height,
+ input_width, input_depth);
+ ConstEigenTensor filter(filter_data, filter_height, filter_width,
+ input_depth, filter_count);
+ output.device(device) =
+ Eigen::SpatialConvolution(input, filter, stride_cols, stride_rows,
+ TfLitePadding2EigenPadding(padding));
+ }
+ }
+};
+
+inline void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, TfLitePadding padding,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims,
+ float* im2col_data, const Dims<4>& im2col_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
+ const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int filter_height = ArraySize(filter_dims, 2);
+ const int filter_width = ArraySize(filter_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ EigenTensorConvFunctor<float> conv_functor;
+ conv_functor(input_data, im2col_data, batches, input_height, input_width,
+ input_depth, filter_data, filter_height, filter_width,
+ output_depth, stride_height, stride_width, pad_height, pad_width,
+ padding, output_data, output_height, output_width);
+
+ optimized_ops::AddBiasAndEvalActivationFunction(
+ bias_data, bias_dims, output_data, output_dims, output_activation_min,
+ output_activation_max);
+}
+
+} // namespace multithreaded_ops
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
new file mode 100644
index 0000000000..bf0bdfb1fb
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -0,0 +1,337 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h"
+
+#ifdef USE_NEON
+
+#include <arm_neon.h>
+#define kFloatWeightsPerNeonLane 4
+
+namespace tflite {
+namespace tensor_utils {
+
+void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
+ int m_cols, const float* vector,
+ int n_batch, float* result,
+ int result_stride) {
+ // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main
+ // vectorized loop, and we need to process sequentially. postamble_start shows
+ // the start index where this should happen.
+ const int postamble_start =
+ m_cols - (m_cols & (kFloatWeightsPerNeonLane - 1));
+
+ // The arrays used to cache the vector.
+ float32x4_t* vector_cache_float32x4 =
+ new float32x4_t[(m_cols / kFloatWeightsPerNeonLane) *
+ sizeof(float32x4_t)];
+ const int kUnrollSize = 2;
+ for (int b = 0; b < n_batch; b++) {
+ float* result_in_batch = result + b * m_rows * result_stride;
+ const float* vector_in_batch = vector + b * m_cols;
+
+ const float* matrix_ptr0 = matrix;
+ // If there is only 1 row, we don't want to assign an illegal pointer.
+ const float* matrix_ptr1 = nullptr;
+ if (m_rows > 1) {
+ matrix_ptr1 = matrix + m_cols;
+ }
+
+ // Cahce the vector.
+ for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) {
+ vector_cache_float32x4[c >> 2] = vld1q_f32(vector_in_batch + c);
+ }
+
+ // Main matrix by vector multiplication loop, which handles two rows of
+ // matrix by vector multiplication.
+ for (int r = 0; r < (m_rows & ~(kUnrollSize - 1)); r += kUnrollSize) {
+ float32x4_t acc0_32x4 = vmovq_n_f32(0.0);
+ float32x4_t acc1_32x4 = vmovq_n_f32(0.0);
+ for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) {
+ float32x4_t temp = vector_cache_float32x4[c >> 2];
+ // Load 4 float values from vector1 and vector2 and accumulator.
+ float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c);
+ float32x4_t v1_f32x4 = vld1q_f32(matrix_ptr1 + c);
+ // Vector multiply-accumulate 4 float
+ acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp);
+ acc1_32x4 = vmlaq_f32(acc1_32x4, v1_f32x4, temp);
+ }
+ // Add the 4 intermediate sum values to get the final dot-prod value for
+ // this column.
+ *result_in_batch +=
+ (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) +
+ vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3));
+ *(result_in_batch + result_stride) +=
+ (vgetq_lane_f32(acc1_32x4, 0) + vgetq_lane_f32(acc1_32x4, 1) +
+ vgetq_lane_f32(acc1_32x4, 2) + vgetq_lane_f32(acc1_32x4, 3));
+ for (int c = postamble_start; c < m_cols; c++) {
+ *result_in_batch += matrix_ptr0[c] * vector_in_batch[c];
+ *(result_in_batch + result_stride) +=
+ matrix_ptr1[c] * vector_in_batch[c];
+ }
+ matrix_ptr0 += kUnrollSize * m_cols;
+ matrix_ptr1 += kUnrollSize * m_cols;
+ result_in_batch += kUnrollSize * result_stride;
+ }
+ for (int r = (m_rows & ~(kUnrollSize - 1)); r < m_rows; r++) {
+ float32x4_t acc0_32x4 = vmovq_n_f32(0.0);
+ for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) {
+ float32x4_t temp = vector_cache_float32x4[c >> 2];
+ // Load 4 float values from vector1 and vector2 and accumulator.
+ float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c);
+ // Vector multiply-accumulate 4 float
+ acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp);
+ }
+ // Add the 4 intermediate sum values to get the final dot-prod value for
+ // this column.
+ *result_in_batch +=
+ (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) +
+ vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3));
+ for (int c = postamble_start; c < m_cols; c++) {
+ *result_in_batch += matrix_ptr0[c] * vector_in_batch[c];
+ }
+ matrix_ptr0 += m_cols;
+ result_in_batch += result_stride;
+ }
+ }
+ delete[] vector_cache_float32x4;
+}
+
+void NeonVectorVectorCwiseProduct(const float* vector1, const float* vector2,
+ int v_size, float* result) {
+ // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main
+ // vectorized loop, and we need to process sequentially. postamble_start shows
+ // the start index where this should happen.
+ const int postamble_start =
+ v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
+ for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) {
+ // Load 4 float values from vector1 and vector2.
+ float32x4_t v1_f32x4 = vld1q_f32(vector1 + v);
+ float32x4_t v2_f32x4 = vld1q_f32(vector2 + v);
+ // Vector multiply 4 float
+ float32x4_t mul_32x4 = vmulq_f32(v1_f32x4, v2_f32x4);
+ // Save to result array.
+ vst1q_f32(&result[v], mul_32x4);
+ }
+ for (int v = postamble_start; v < v_size; v++) {
+ result[v] = vector1[v] * vector2[v];
+ }
+}
+
+void NeonVectorVectorCwiseProductAccumulate(const float* vector1,
+ const float* vector2, int v_size,
+ float* result) {
+ // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main
+ // vectorized loop, and we need to process sequentially. postamble_start shows
+ // the start index where this should happen.
+ const int postamble_start =
+ v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
+ for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) {
+ // Load 4 float values from vector1 and vector2 and accumulator.
+ float32x4_t v1_f32x4 = vld1q_f32(vector1 + v);
+ float32x4_t v2_f32x4 = vld1q_f32(vector2 + v);
+ float32x4_t acc_32x4 = vld1q_f32(result + v);
+ // Vector multiply-accumulate 4 float
+ acc_32x4 = vmlaq_f32(acc_32x4, v1_f32x4, v2_f32x4);
+ // Save to result array.
+ vst1q_f32(&result[v], acc_32x4);
+ }
+ for (int v = postamble_start; v < v_size; v++) {
+ result[v] += vector1[v] * vector2[v];
+ }
+}
+
+void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector,
+ int v_size,
+ const float* batch_vector,
+ int n_batch, float* result) {
+ // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main
+ // vectorized loop, and we need to process sequentially. postamble_start shows
+ // the start index where this should happen.
+ const int postamble_start =
+ v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
+
+ // The arrays used to cache the vector.
+ float32x4_t* vector_cache_float32x4 =
+ new float32x4_t[(v_size / kFloatWeightsPerNeonLane) *
+ sizeof(float32x4_t)];
+ for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) {
+ vector_cache_float32x4[v >> 2] = vld1q_f32(vector + v);
+ }
+
+ float* result_ptr = result;
+ const float* batch_vector_ptr = batch_vector;
+ for (int b = 0; b < n_batch; b++) {
+ for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) {
+ // Load from memory to vectors.
+ float32x4_t result_f32x4 = vld1q_f32(result_ptr + v);
+ float32x4_t batch_vector_f32x4 = vld1q_f32(batch_vector_ptr + v);
+ // Multiply-accumulate.
+ result_f32x4 = vmlaq_f32(result_f32x4, batch_vector_f32x4,
+ vector_cache_float32x4[v >> 2]);
+ // Store.
+ vst1q_f32(result_ptr + v, result_f32x4);
+ }
+ // Postamble loop
+ for (int v = postamble_start; v < v_size; v++) {
+ result_ptr[v] += vector[v] * batch_vector_ptr[v];
+ }
+ // Update the pointers.
+ result_ptr += v_size;
+ batch_vector_ptr += v_size;
+ }
+ delete[] vector_cache_float32x4;
+}
+
+void NeonSub1Vector(const float* vector, int v_size, float* result) {
+ // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main
+ // vectorized loop, and we need to process sequentially. postamble_start shows
+ // the start index where this should happen.
+ const int postamble_start =
+ v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
+
+ float32x4_t one_f32x4 = vmovq_n_f32(1.0);
+ for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) {
+ // Load 4 float values from the current pointers of the input column and
+ // subtract from 1.
+ float32x4_t v_f32x4 = vld1q_f32(vector + v);
+ float32x4_t result_f32x4 = vsubq_f32(one_f32x4, v_f32x4);
+ // Save to output.
+ vst1q_f32(result + v, result_f32x4);
+ }
+ for (int v = postamble_start; v < v_size; v++) {
+ result[v] = 1.0f - vector[v];
+ }
+}
+
+void NeonClipVector(const float* vector, int v_size, float abs_limit,
+ float* result) {
+ // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main
+ // vectorized loop, and we need to process sequentially. postamble_start shows
+ // the start index where this should happen.
+ const int postamble_start =
+ v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
+
+ // Replicate abs_limit and -abs_limit in two vectors.
+ const float32x4_t abs_limit_f32x4 = vmovq_n_f32(abs_limit);
+ const float32x4_t neg_abs_limit_f32x4 = vmovq_n_f32(-abs_limit);
+
+ for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) {
+ // Load from memory to vector.
+ float32x4_t v_f32x4 = vld1q_f32(vector + v);
+ // Clip between abs_limit and -abs_limit.
+ float32x4_t result_f32x4 = vminq_f32(abs_limit_f32x4, v_f32x4);
+ result_f32x4 = vmaxq_f32(neg_abs_limit_f32x4, result_f32x4);
+ // Save to output.
+ vst1q_f32(result + v, result_f32x4);
+ }
+ // Postamble loop.
+ for (int v = postamble_start; v < v_size; v++) {
+ result[v] = (abs_limit < vector[v]) ? abs_limit : vector[v];
+ result[v] = (-abs_limit > result[v]) ? -abs_limit : result[v];
+ }
+}
+
+float NeonVectorVectorDotProduct(const float* vector1, const float* vector2,
+ int v_size) {
+ // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main
+ // vectorized loop, and we need to process sequentially. postamble_start shows
+ // the start index where this should happen.
+ const int postamble_start =
+ v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
+ float32x4_t acc_32x4 = vmovq_n_f32(0.0);
+ for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) {
+ // Load 4 float values from vector1 and vector2 and accumulator.
+ float32x4_t v1_f32x4 = vld1q_f32(vector1 + v);
+ float32x4_t v2_f32x4 = vld1q_f32(vector2 + v);
+ // Vector multiply-accumulate 4 float
+ acc_32x4 = vmlaq_f32(acc_32x4, v1_f32x4, v2_f32x4);
+ }
+
+ float result = (vgetq_lane_f32(acc_32x4, 0) + vgetq_lane_f32(acc_32x4, 1) +
+ vgetq_lane_f32(acc_32x4, 2) + vgetq_lane_f32(acc_32x4, 3));
+ // Postamble loop.
+ for (int v = postamble_start; v < v_size; v++) {
+ result += vector1[v] * vector2[v];
+ }
+ return result;
+}
+
+void NeonBatchVectorBatchVectorDotProduct(const float* vector1,
+ const float* vector2, int v_size,
+ int n_batch, float* result,
+ int result_stride) {
+ float* result_ptr = result;
+ const float* vector1_ptr = vector1;
+ const float* vector2_ptr = vector2;
+ for (int b = 0; b < n_batch; b++) {
+ *result_ptr = NeonVectorVectorDotProduct(vector1_ptr, vector2_ptr, v_size);
+ vector1_ptr += v_size;
+ vector2_ptr += v_size;
+ result_ptr += result_stride;
+ }
+}
+
+void NeonReductionSumVector(const float* input_vector, float* output_vector,
+ int output_size, int reduction_size) {
+ const float* input_vector_ptr = input_vector;
+ for (int o = 0; o < output_size; o++) {
+ // If reduction_size is not divisible by kWeightsPerNeonLane, we cannot use
+ // the main vectorized loop, and we need to process sequentially.
+ // postamble_start shows the start index where this should happen.
+ const int postamble_start =
+ reduction_size - (reduction_size & (kFloatWeightsPerNeonLane - 1));
+ float32x4_t sum_f32x4 = vmovq_n_f32(0.0);
+ for (int r = 0; r < postamble_start; r += kFloatWeightsPerNeonLane) {
+ float32x4_t v1_f32x4 = vld1q_f32(input_vector_ptr + r);
+ sum_f32x4 = vaddq_f32(sum_f32x4, v1_f32x4);
+ }
+ output_vector[o] +=
+ (vgetq_lane_f32(sum_f32x4, 0) + vgetq_lane_f32(sum_f32x4, 1) +
+ vgetq_lane_f32(sum_f32x4, 2) + vgetq_lane_f32(sum_f32x4, 3));
+ input_vector_ptr += postamble_start;
+
+ // Postamble loop.
+ for (int r = postamble_start; r < reduction_size; r++) {
+ output_vector[o] += *input_vector_ptr++;
+ }
+ }
+}
+
+void NeonVectorShiftLeft(float* vector, int v_size, float shift_value) {
+ // This variable keeps track of the next to the last index which is being
+ // copied to make sure we are not out of the vector boundary.
+ int last_index_copy = kFloatWeightsPerNeonLane;
+ int current_index_copy = 0;
+ while (last_index_copy < v_size) {
+ float32x4_t v_f32x4 = vld1q_f32(vector + current_index_copy + 1);
+ vst1q_f32(vector + current_index_copy, v_f32x4);
+ current_index_copy += kFloatWeightsPerNeonLane;
+ last_index_copy += kFloatWeightsPerNeonLane;
+ }
+ // Postamble loop.
+ for (int i = current_index_copy; i < v_size - 1; i++) {
+ vector[i] = vector[i + 1];
+ }
+ vector[v_size - 1] = shift_value;
+}
+
+} // namespace tensor_utils
+} // namespace tflite
+
+#endif // USE_NEON
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
new file mode 100644
index 0000000000..3a4af87304
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
@@ -0,0 +1,113 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_
+
+// TODO(ghodrat): Remove this header file and the dependency to internal data
+// structure.
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h"
+
+namespace tflite {
+namespace tensor_utils {
+
+void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
+ int m_cols, const float* vector,
+ int n_batch, float* result,
+ int result_stride) {
+ NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols,
+ vector, n_batch, result, result_stride);
+}
+
+void VectorVectorCwiseProduct(const float* vector1, const float* vector2,
+ int v_size, float* result) {
+ NEON_OR_PORTABLE(VectorVectorCwiseProduct, vector1, vector2, v_size, result);
+}
+
+void VectorVectorCwiseProductAccumulate(const float* vector1,
+ const float* vector2, int v_size,
+ float* result) {
+ NEON_OR_PORTABLE(VectorVectorCwiseProductAccumulate, vector1, vector2, v_size,
+ result);
+}
+
+void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size,
+ const float* batch_vector,
+ int n_batch, float* result) {
+ NEON_OR_PORTABLE(VectorBatchVectorCwiseProductAccumulate, vector, v_size,
+ batch_vector, n_batch, result);
+}
+
+float VectorVectorDotProduct(const float* vector1, const float* vector2,
+ int v_size) {
+ return NEON_OR_PORTABLE(VectorVectorDotProduct, vector1, vector2, v_size);
+}
+
+void BatchVectorBatchVectorDotProduct(const float* vector1,
+ const float* vector2, int v_size,
+ int n_batch, float* result,
+ int result_stride) {
+ NEON_OR_PORTABLE(BatchVectorBatchVectorDotProduct, vector1, vector2, v_size,
+ n_batch, result, result_stride);
+}
+
+void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch,
+ float* batch_vector) {
+ PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector);
+}
+
+void ApplySigmoidToVector(const float* vector, int v_size, float* result) {
+ PortableApplySigmoidToVector(vector, v_size, result);
+}
+
+void ApplyActivationToVector(const float* vector, int v_size,
+ TfLiteFusedActivation activation, float* result) {
+ PortableApplyActivationToVector(vector, v_size, activation, result);
+}
+
+void CopyVector(const float* vector, int v_size, float* result) {
+ PortableCopyVector(vector, v_size, result);
+}
+
+void Sub1Vector(const float* vector, int v_size, float* result) {
+ NEON_OR_PORTABLE(Sub1Vector, vector, v_size, result);
+}
+
+void ZeroVector(float* vector, int v_size) {
+ PortableZeroVector(vector, v_size);
+}
+
+float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); }
+
+void ClipVector(const float* vector, int v_size, float abs_limit,
+ float* result) {
+ NEON_OR_PORTABLE(ClipVector, vector, v_size, abs_limit, result);
+}
+
+void VectorShiftLeft(float* vector, int v_size, float shift_value) {
+ NEON_OR_PORTABLE(VectorShiftLeft, vector, v_size, shift_value);
+}
+
+void ReductionSumVector(const float* input_vector, float* output_vector,
+ int output_size, int reduction_size) {
+ NEON_OR_PORTABLE(ReductionSumVector, input_vector, output_vector, output_size,
+ reduction_size);
+}
+
+} // namespace tensor_utils
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
new file mode 100644
index 0000000000..cd565c16a1
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -0,0 +1,3715 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_
+
+#include <assert.h>
+#include <stdint.h>
+#include <sys/types.h>
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <memory>
+#include <tuple>
+#include <type_traits>
+
+#include "third_party/eigen3/Eigen/Core"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "fixedpoint/fixedpoint.h"
+#include "public/gemmlowp.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/round.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace optimized_ops {
+
+// Make a local VectorMap typedef allowing to map a float array
+// as a Eigen vector expression. The std::conditional here is to
+// construct the suitable Eigen type for the constness of the
+// data. Indeed, for const data, we need to produce
+// Eigen::Map<const Eigen::Matrix<float, ...>>
+// and not the more straightforward
+// Eigen::Map<Eigen::Matrix<const float, ...>>
+template <typename Scalar>
+using VectorMap = typename std::conditional<
+ std::is_const<Scalar>::value,
+ Eigen::Map<const Eigen::Matrix<typename std::remove_const<Scalar>::type,
+ Eigen::Dynamic, 1>>,
+ Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, 1>>>::type;
+
+template <typename Scalar, int N>
+VectorMap<Scalar> MapAsVector(Scalar* data, const Dims<N>& dims) {
+ const int size = RequiredBufferSizeForDims(dims);
+ return VectorMap<Scalar>(data, size, 1);
+}
+
+// Make a local VectorMap typedef allowing to map a float array
+// as a Eigen matrix expression. The same explanation as for VectorMap
+// above also applies here.
+template <typename Scalar>
+using MatrixMap = typename std::conditional<
+ std::is_const<Scalar>::value,
+ Eigen::Map<const Eigen::Matrix<typename std::remove_const<Scalar>::type,
+ Eigen::Dynamic, Eigen::Dynamic>>,
+ Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
+
+template <typename Scalar, int N>
+MatrixMap<Scalar> MapAsMatrixWithFirstDimAsRows(Scalar* data,
+ const Dims<N>& dims) {
+ const int rows = dims.sizes[0];
+ int cols = 1;
+ for (int d = 1; d < N; d++) {
+ cols *= dims.sizes[d];
+ }
+ return MatrixMap<Scalar>(data, rows, cols);
+}
+
+template <typename Scalar, int N>
+MatrixMap<Scalar> MapAsMatrixWithLastDimAsCols(Scalar* data,
+ const Dims<N>& dims) {
+ const int cols = dims.sizes[N - 1];
+ int rows = 1;
+ for (int d = 0; d < N - 1; d++) {
+ rows *= dims.sizes[d];
+ }
+ return MatrixMap<Scalar>(data, rows, cols);
+}
+
+template <typename Scalar>
+using ArrayMap = typename std::conditional<
+ std::is_const<Scalar>::value,
+ Eigen::Map<const Eigen::Array<typename std::remove_const<Scalar>::type,
+ Eigen::Dynamic, Eigen::Dynamic>>,
+ Eigen::Map<Eigen::Array<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
+
+template <typename Scalar, int N>
+ArrayMap<Scalar> MapAsArrayWithFirstDimAsRows(Scalar* data,
+ const Dims<N>& dims) {
+ const int rows = dims.sizes[0];
+ int cols = 1;
+ for (int d = 1; d < N; d++) {
+ cols *= dims.sizes[d];
+ }
+ return ArrayMap<Scalar>(data, rows, cols);
+}
+
+// TODO(b/62193649): this function is only needed as long
+// as we have the --variable_batch hack.
+template <typename Scalar, int N>
+MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
+ const Dims<N>& dims,
+ int rows) {
+ int cols = 1;
+ bool matched_rows = false;
+ for (int d = 0; d < N; d++) {
+ cols *= dims.sizes[d];
+ if (cols == rows) {
+ matched_rows = true;
+ cols = 1;
+ }
+ }
+ TFLITE_DCHECK(matched_rows);
+ return MatrixMap<Scalar>(data, rows, cols);
+}
+
+// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE
+// BROADCASTING.
+//
+// NdArrayDesc<N> describes the shape and memory layout of an N-dimensional
+// rectangular array of numbers.
+//
+// NdArrayDesc<N> is basically identical to Dims<N> defined in types.h.
+// However, as Dims<N> is to be deprecated, this class exists as an adaptor
+// to enable simple unoptimized implementations of element-wise broadcasting
+// operations.
+template <int N>
+struct NdArrayDesc {
+ // The "extent" of each dimension. Indices along dimension d must be in the
+ // half-open interval [0, extents[d]).
+ int extents[N];
+
+ // The number of *elements* (not bytes) between consecutive indices of each
+ // dimension.
+ int strides[N];
+};
+
+// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
+// ELEMENT-WISE BROADCASTING.
+//
+// Same as Offset(), except takes as NdArrayDesc<N> instead of Dims<N>.
+inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2,
+ int i3) {
+ TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]);
+ TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]);
+ TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]);
+ TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]);
+ return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] +
+ i3 * desc.strides[3];
+}
+
+// Given the dimensions of the operands for an element-wise binary broadcast,
+// adjusts them so that they can be directly iterated over with simple loops.
+// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
+// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr.
+//
+// This function assumes that the two input shapes are compatible up to
+// broadcasting and the shorter one has already been prepended with 1s to be the
+// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64),
+// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that
+// Dims<N> refer to shapes in reverse order. In this case, input0_dims will be
+// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1).
+//
+// When two shapes are compatible up to broadcasting, for each dimension d,
+// the input extents are either equal, or one of them is 1.
+//
+// This function performs the following for each dimension d:
+// - If the extents are equal, then do nothing since the loop that walks over
+// both of the input arrays is correct.
+// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1
+// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows
+// array0 to be referenced *at any index* in dimension d and still access the
+// same slice.
+template <int N>
+inline void NdArrayDescsForElementwiseBroadcast(const Dims<N>& input0_dims,
+ const Dims<N>& input1_dims,
+ NdArrayDesc<N>* desc0_out,
+ NdArrayDesc<N>* desc1_out) {
+ TFLITE_DCHECK(desc0_out != nullptr);
+ TFLITE_DCHECK(desc1_out != nullptr);
+
+ // Copy dims to desc.
+ for (int i = 0; i < N; ++i) {
+ desc0_out->extents[i] = input0_dims.sizes[i];
+ desc0_out->strides[i] = input0_dims.strides[i];
+ desc1_out->extents[i] = input1_dims.sizes[i];
+ desc1_out->strides[i] = input1_dims.strides[i];
+ }
+
+ // Walk over each dimension. If the extents are equal do nothing.
+ // Otherwise, set the desc with extent 1 to have extent equal to the other and
+ // stride 0.
+ for (int i = 0; i < N; ++i) {
+ const int extent0 = ArraySize(input0_dims, i);
+ const int extent1 = ArraySize(input1_dims, i);
+ if (extent0 != extent1) {
+ if (extent0 == 1) {
+ desc0_out->strides[i] = 0;
+ desc0_out->extents[i] = extent1;
+ } else {
+ TFLITE_DCHECK_EQ(extent1, 1);
+ desc1_out->strides[i] = 0;
+ desc1_out->extents[i] = extent0;
+ }
+ }
+ }
+}
+
+inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) {
+ for (int i = 0; i < 4; i++) {
+ if (dims1.sizes[i] != dims2.sizes[i]) {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline void AddBiasAndEvalActivationFunction(const float* bias_data,
+ const Dims<4>& bias_dims,
+ float* array_data,
+ const Dims<4>& array_dims,
+ float output_activation_min,
+ float output_activation_max) {
+#ifdef USE_NEON
+ gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction");
+ const int bias_size = bias_dims.sizes[3] * bias_dims.strides[3];
+ const int array_size = array_dims.sizes[3] * array_dims.strides[3];
+ TFLITE_DCHECK_EQ((array_size % bias_size), 0);
+ float* array_ptr = array_data;
+ float* array_end_ptr = array_ptr + array_size;
+ const auto activation_min = vdupq_n_f32(output_activation_min);
+ const auto activation_max = vdupq_n_f32(output_activation_max);
+ for (; array_ptr != array_end_ptr; array_ptr += bias_size) {
+ int i = 0;
+ for (; i <= bias_size - 16; i += 16) {
+ auto b0 = vld1q_f32(bias_data + i);
+ auto b1 = vld1q_f32(bias_data + i + 4);
+ auto b2 = vld1q_f32(bias_data + i + 8);
+ auto b3 = vld1q_f32(bias_data + i + 12);
+ auto a0 = vld1q_f32(array_ptr + i);
+ auto a1 = vld1q_f32(array_ptr + i + 4);
+ auto a2 = vld1q_f32(array_ptr + i + 8);
+ auto a3 = vld1q_f32(array_ptr + i + 12);
+ auto x0 = vaddq_f32(a0, b0);
+ auto x1 = vaddq_f32(a1, b1);
+ auto x2 = vaddq_f32(a2, b2);
+ auto x3 = vaddq_f32(a3, b3);
+ x0 = vmaxq_f32(activation_min, x0);
+ x1 = vmaxq_f32(activation_min, x1);
+ x2 = vmaxq_f32(activation_min, x2);
+ x3 = vmaxq_f32(activation_min, x3);
+ x0 = vminq_f32(activation_max, x0);
+ x1 = vminq_f32(activation_max, x1);
+ x2 = vminq_f32(activation_max, x2);
+ x3 = vminq_f32(activation_max, x3);
+ vst1q_f32(array_ptr + i, x0);
+ vst1q_f32(array_ptr + i + 4, x1);
+ vst1q_f32(array_ptr + i + 8, x2);
+ vst1q_f32(array_ptr + i + 12, x3);
+ }
+ for (; i <= bias_size - 4; i += 4) {
+ auto b = vld1q_f32(bias_data + i);
+ auto a = vld1q_f32(array_ptr + i);
+ auto x = vaddq_f32(a, b);
+ x = vmaxq_f32(activation_min, x);
+ x = vminq_f32(activation_max, x);
+ vst1q_f32(array_ptr + i, x);
+ }
+ for (; i < bias_size; i++) {
+ array_ptr[i] = ActivationFunctionWithMinMax(array_ptr[i] + bias_data[i],
+ output_activation_min,
+ output_activation_max);
+ }
+ }
+#else // not NEON
+ gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction");
+ const int bias_size = bias_dims.sizes[3] * bias_dims.strides[3];
+ const int array_size = array_dims.sizes[3] * array_dims.strides[3];
+ TFLITE_DCHECK_EQ((array_size % bias_size), 0);
+ for (int array_offset = 0; array_offset < array_size;
+ array_offset += bias_size) {
+ for (int i = 0; i < bias_size; i++) {
+ array_data[array_offset + i] = ActivationFunctionWithMinMax(
+ array_data[array_offset + i] + bias_data[i], output_activation_min,
+ output_activation_max);
+ }
+ }
+#endif
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AddBiasAndEvalActivationFunction(const float* bias_data,
+ const Dims<4>& bias_dims,
+ float* array_data,
+ const Dims<4>& array_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ AddBiasAndEvalActivationFunction(bias_data, bias_dims, array_data, array_dims,
+ output_activation_min,
+ output_activation_max);
+}
+
+template <typename Lhs, typename Rhs, typename Result>
+void Gemm(const Eigen::MatrixBase<Lhs>& lhs, const Eigen::MatrixBase<Rhs>& rhs,
+ Eigen::MatrixBase<Result>* result) {
+ if (rhs.cols() == 1) {
+ gemmlowp::ScopedProfilingLabel label("GEMV");
+ result->col(0).noalias() = lhs * rhs.col(0);
+ } else {
+ gemmlowp::ScopedProfilingLabel label("GEMM");
+ result->noalias() = lhs * rhs;
+ }
+}
+
+inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("FullyConnected");
+ // TODO(b/62193649): this convoluted shape computation (determining
+ // input_rows from the weights_dims, then MapAsMatrixWithGivenNumberOfRows)
+ // is because the current --variable_batch hack consists in overwriting the
+ // 3rd dimension with the runtime batch size, as we don't keep track for each
+ // array of which dimension is the batch dimension in it.
+ // When that is fixed, this should become:
+ // const auto input_matrix_map =
+ // MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
+ const int input_rows = ArraySize(weights_dims, 0);
+ const auto input_matrix_map =
+ MapAsMatrixWithGivenNumberOfRows(input_data, input_dims, input_rows);
+ const auto filter_matrix_map =
+ MapAsMatrixWithFirstDimAsRows(weights_data, weights_dims);
+ auto output_matrix_map =
+ MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+
+ Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);
+ AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data,
+ output_dims, output_activation_min,
+ output_activation_max);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data, const Dims<4>& weights_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
+ bias_dims, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+inline void preload_l1_stream(const uint8* ptr) {
+#ifdef GEMMLOWP_ARM_64
+ asm volatile("prfm pldl1strm, [%[ptr]]\n" ::[ptr] "r"(ptr) :);
+#else
+ gemmlowp::Prefetch(ptr);
+#endif
+}
+
+#ifdef USE_NEON
+inline void FullyConnectedAsGEMV(
+ const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
+ const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims, int32 output_offset,
+ int32 output_multiplier, int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("FullyConnectedAsGEMV/8bit");
+ TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ TFLITE_DCHECK_EQ(ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
+ ArraySize(output_dims, 3),
+ 1);
+ const int input_size = input_dims.strides[3];
+ const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0);
+ static constexpr int kPeel = 4;
+ for (int k = 0; k < input_size; k += 64) {
+ preload_l1_stream(input_data + k);
+ }
+ for (int k = 0; k < kPeel * input_size; k += 64) {
+ preload_l1_stream(filter_data + k);
+ }
+ TFLITE_DCHECK(!(output_size % kPeel));
+ const int32* bias_ptr = bias_data;
+ uint8* output_ptr = output_data;
+ for (int out = 0; out < output_size; out += kPeel) {
+ int32x4_t acc[kPeel];
+ for (int k = 0; k < kPeel; k++) {
+ acc[k] = vdupq_n_s32(0);
+ }
+ const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
+ const int16x8_t filter_offset_vec = vdupq_n_s16(filter_offset);
+ int in = 0;
+ for (; in <= input_size - 16; in += 16) {
+ const uint8x16_t input_val_u8 = vld1q_u8(input_data + in);
+ uint8x16_t filter_val_u8[kPeel];
+ for (int k = 0; k < kPeel; k++) {
+ const uint8* filter_ptr = filter_data + in + (out + k) * input_size;
+ filter_val_u8[k] = vld1q_u8(filter_ptr);
+ preload_l1_stream(filter_ptr + 64);
+ }
+ int16x8_t input_val[2];
+ const uint8x8_t low = vget_low_u8(input_val_u8);
+ const uint8x8_t high = vget_high_u8(input_val_u8);
+ input_val[0] = vreinterpretq_s16_u16(vmovl_u8(low));
+ input_val[1] = vreinterpretq_s16_u16(vmovl_u8(high));
+ input_val[0] = vaddq_s16(input_val[0], input_offset_vec);
+ input_val[1] = vaddq_s16(input_val[1], input_offset_vec);
+ int16x8_t filter_val[kPeel][2];
+ for (int k = 0; k < kPeel; k++) {
+ const uint8x8_t low = vget_low_u8(filter_val_u8[k]);
+ const uint8x8_t high = vget_high_u8(filter_val_u8[k]);
+ filter_val[k][0] = vreinterpretq_s16_u16(vmovl_u8(low));
+ filter_val[k][1] = vreinterpretq_s16_u16(vmovl_u8(high));
+ filter_val[k][0] = vaddq_s16(filter_val[k][0], filter_offset_vec);
+ filter_val[k][1] = vaddq_s16(filter_val[k][1], filter_offset_vec);
+ }
+ for (int p = 0; p < 2; p++) {
+ for (int k = 0; k < kPeel; k++) {
+ acc[k] = vmlal_s16(acc[k], vget_low_s16(filter_val[k][p]),
+ vget_low_s16(input_val[p]));
+ }
+ for (int k = 0; k < kPeel; k++) {
+ acc[k] = vmlal_s16(acc[k], vget_high_s16(filter_val[k][p]),
+ vget_high_s16(input_val[p]));
+ }
+ }
+ }
+ for (; in <= input_size - 8; in += 8) {
+ const uint8x8_t input_val_u8 = vld1_u8(input_data + in);
+ uint8x8_t filter_val_u8[kPeel];
+ for (int k = 0; k < kPeel; k++) {
+ const uint8* filter_ptr = filter_data + in + (out + k) * input_size;
+ filter_val_u8[k] = vld1_u8(filter_ptr);
+ }
+ int16x8_t input_val;
+ input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8));
+ input_val = vaddq_s16(input_val, input_offset_vec);
+ int16x8_t filter_val[kPeel];
+ for (int k = 0; k < kPeel; k++) {
+ filter_val[k] = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8[k]));
+ filter_val[k] = vaddq_s16(filter_val[k], filter_offset_vec);
+ }
+ for (int k = 0; k < kPeel; k++) {
+ acc[k] = vmlal_s16(acc[k], vget_low_s16(filter_val[k]),
+ vget_low_s16(input_val));
+ }
+ for (int k = 0; k < kPeel; k++) {
+ acc[k] = vmlal_s16(acc[k], vget_high_s16(filter_val[k]),
+ vget_high_s16(input_val));
+ }
+ }
+ if (in < input_size) {
+ int32 buf[4 * kPeel];
+ for (int k = 0; k < 4; k++) {
+ vst1q_s32(buf + 4 * k, acc[k]);
+ }
+ for (; in < input_size; in++) {
+ int lane = (in + 8 - input_size) % 4;
+ const int32 input_val = input_data[in] + input_offset;
+ for (int k = 0; k < kPeel; k++) {
+ int32 filter_val =
+ filter_data[in + (out + k) * input_size] + filter_offset;
+ buf[lane + 4 * k] += filter_val * input_val;
+ }
+ }
+ for (int k = 0; k < 4; k++) {
+ acc[k] = vld1q_s32(buf + 4 * k);
+ }
+ }
+
+ // Horizontally reduce accumulators
+ int32x2_t pairwise_reduced_acc[kPeel];
+ for (int k = 0; k < kPeel; k++) {
+ pairwise_reduced_acc[k] =
+ vpadd_s32(vget_low_s32(acc[k]), vget_high_s32(acc[k]));
+ }
+ static_assert(kPeel == 4, "the code below currently assumes kPeel = 4");
+ const int32x2_t reduced_lo =
+ vpadd_s32(pairwise_reduced_acc[0], pairwise_reduced_acc[1]);
+ const int32x2_t reduced_hi =
+ vpadd_s32(pairwise_reduced_acc[2], pairwise_reduced_acc[3]);
+ int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
+ // Add bias values.
+ int32x4_t bias_vec = vld1q_s32(bias_ptr);
+ bias_ptr += 4;
+ reduced = vaddq_s32(reduced, bias_vec);
+ // Multiply by the fixed-point multiplier.
+ reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
+ // Rounding-shift-right.
+ using gemmlowp::RoundingDivideByPOT;
+ reduced = RoundingDivideByPOT(reduced, output_shift);
+ // Add the output offset.
+ const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
+ reduced = vaddq_s32(reduced, output_offset_vec);
+ // Narrow values down to 16 bit signed.
+ const int16x4_t res16 = vqmovn_s32(reduced);
+ // Narrow values down to 8 bit unsigned, saturating.
+ uint8x8_t res8 = vqmovun_s16(vcombine_s16(res16, res16));
+ // Apply the clamping from the activation function
+ res8 = vmax_u8(res8, vdup_n_u8(output_activation_min));
+ res8 = vmin_u8(res8, vdup_n_u8(output_activation_max));
+ // Store results to destination. Assumes 32bit alignment.
+ vst1_lane_u32(reinterpret_cast<uint32*>(output_ptr),
+ vreinterpret_u32_u8(res8), 0);
+ output_ptr += kPeel;
+ }
+}
+#endif // USE_NEON
+
+struct GemmlowpOutputPipeline {
+ typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
+ ColVectorMap;
+ typedef std::tuple<
+ gemmlowp::OutputStageBiasAddition<ColVectorMap>,
+ gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint,
+ gemmlowp::OutputStageClamp, gemmlowp::OutputStageSaturatingCastToUint8>
+ Pipeline;
+ static Pipeline Make(const int32* bias_data, int output_rows,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max) {
+ ColVectorMap bias_vector(bias_data, output_rows);
+ gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
+ bias_addition_stage.bias_vector = bias_vector;
+ gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint
+ quantize_down_stage;
+ quantize_down_stage.result_offset_after_shift = output_offset;
+ quantize_down_stage.result_fixedpoint_multiplier = output_multiplier;
+ quantize_down_stage.result_shift = output_shift;
+ gemmlowp::OutputStageClamp clamp_stage;
+ clamp_stage.min = output_activation_min;
+ clamp_stage.max = output_activation_max;
+ gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage;
+ return std::make_tuple(bias_addition_stage, quantize_down_stage,
+ clamp_stage, saturating_cast_stage);
+ }
+};
+
+inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ gemmlowp::ScopedProfilingLabel label("FullyConnected/8bit");
+ // TODO(benoitjacob): This really should be:
+ // const int batches = ArraySize(output_dims, 1);
+ // but the current --variable_batch hack consists in overwriting the 3rd
+ // dimension with the runtime batch size, as we don't keep track for each
+ // array of which dimension is the batch dimension in it.
+ const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
+ ArraySize(output_dims, 3);
+#ifdef USE_NEON
+ const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0);
+ if (batches == 1 && !(output_size % 4)) {
+ return FullyConnectedAsGEMV(
+ input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, output_offset, output_multiplier,
+ output_shift, output_activation_min, output_activation_max, output_data,
+ output_dims);
+ }
+#endif // USE_NEON
+ const int filter_rows = filter_dims.sizes[1];
+ const int filter_cols = filter_dims.sizes[0];
+ TFLITE_DCHECK_EQ(filter_dims.sizes[2], 1);
+ TFLITE_DCHECK_EQ(filter_dims.sizes[3], 1);
+ const int output_rows = output_dims.sizes[0];
+ TFLITE_DCHECK_EQ(output_rows, filter_rows);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
+
+ gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
+ filter_data, output_rows, filter_cols, filter_cols);
+ gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
+ input_data, filter_cols, batches, filter_cols);
+ gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
+ output_data, output_rows, batches, output_rows);
+ const auto& output_pipeline = GemmlowpOutputPipeline::Make(
+ bias_data, output_rows, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max);
+ gemmlowp::GemmWithOutputPipeline<uint8, uint8,
+ gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
+ gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset,
+ input_offset, output_pipeline);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, output_offset,
+ output_multiplier, output_shift, output_activation_min,
+ output_activation_max, output_data, output_dims, gemm_context);
+}
+
+template <typename T>
+inline void ExtractPatchIntoBufferColumn(
+ const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int in_width, int in_height, int in_depth, int single_buffer_length,
+ int buffer_id, const T* in_data, T* conv_buffer_data, uint8 byte_zero) {
+ gemmlowp::ScopedProfilingLabel label("ExtractPatchIntoBufferColumn");
+ // This chunk of code reshapes all the inputs corresponding to
+ // output (b, h, w) to a column vector in conv_buffer(:, buffer_id).
+ const int kwidth_times_indepth = kwidth * in_depth;
+ const int inwidth_times_indepth = in_width * in_depth;
+ const int ih_ungated_start = h * stride_height - pad_height;
+ const int ih_ungated_end = (ih_ungated_start + kheight);
+ const int ih_end = std::min(ih_ungated_end, in_height);
+ const int iw_ungated_start = w * stride_width - pad_width;
+ const int iw_ungated_end = (iw_ungated_start + kwidth);
+ const int iw_end = std::min(iw_ungated_end, in_width);
+ // If the patch is off the edge of the input image, skip writing those rows
+ // and columns from the patch into the output array.
+ const int h_offset = std::max(0, -ih_ungated_start);
+ const int w_offset = std::max(0, -iw_ungated_start);
+ const int ih_start = std::max(0, ih_ungated_start);
+ const int iw_start = std::max(0, iw_ungated_start);
+ const int single_row_num =
+ std::min(kwidth - w_offset, in_width - iw_start) * in_depth;
+ const int output_row_offset = (buffer_id * single_buffer_length);
+ int out_offset =
+ output_row_offset + (h_offset * kwidth + w_offset) * in_depth;
+ int in_offset = Offset(input_dims, 0, iw_start, ih_start, b);
+
+ // Express all of the calculations as padding around the input patch.
+ const int top_padding = h_offset;
+ const int bottom_padding = (ih_ungated_end - ih_end);
+ const int left_padding = w_offset;
+ const int right_padding = (iw_ungated_end - iw_end);
+ assert(single_row_num ==
+ ((kwidth - (left_padding + right_padding)) * in_depth));
+
+ // Write out zeroes to the elements representing the top rows of the input
+ // patch that are off the edge of the input image.
+ if (top_padding > 0) {
+ const int top_row_elements = (top_padding * kwidth * in_depth);
+ memset(conv_buffer_data + output_row_offset, byte_zero,
+ (top_row_elements * sizeof(T)));
+ }
+
+ // If the patch is on the interior of the input image horizontally, just copy
+ // over the rows sequentially, otherwise add zero padding at the start or end.
+ if ((left_padding == 0) && (right_padding == 0)) {
+ for (int ih = ih_start; ih < ih_end; ++ih) {
+ memcpy(conv_buffer_data + out_offset, in_data + in_offset,
+ single_row_num * sizeof(T));
+ out_offset += kwidth_times_indepth;
+ in_offset += inwidth_times_indepth;
+ }
+ } else {
+ for (int ih = ih_start; ih < ih_end; ++ih) {
+ if (left_padding > 0) {
+ const int left_start = (out_offset - (left_padding * in_depth));
+ memset(conv_buffer_data + left_start, byte_zero,
+ (left_padding * in_depth * sizeof(T)));
+ }
+ memcpy(conv_buffer_data + out_offset, in_data + in_offset,
+ single_row_num * sizeof(T));
+ if (right_padding > 0) {
+ const int right_start = (out_offset + single_row_num);
+ memset(conv_buffer_data + right_start, byte_zero,
+ (right_padding * in_depth * sizeof(T)));
+ }
+ out_offset += kwidth_times_indepth;
+ in_offset += inwidth_times_indepth;
+ }
+ }
+
+ // If the bottom of the patch falls off the input image, pad the values
+ // representing those input rows with zeroes.
+ if (bottom_padding > 0) {
+ const int bottom_row_elements = (bottom_padding * kwidth * in_depth);
+ const int bottom_start =
+ output_row_offset +
+ ((top_padding + (ih_end - ih_start)) * kwidth * in_depth);
+ memset(conv_buffer_data + bottom_start, byte_zero,
+ (bottom_row_elements * sizeof(T)));
+ }
+}
+
+template <typename T>
+void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height, int kheight,
+ int kwidth, uint8 byte_zero, T* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Im2col");
+ TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int input_depth = ArraySize(input_dims, 0);
+ const int input_width = ArraySize(input_dims, 1);
+ const int input_height = ArraySize(input_dims, 2);
+ const int output_depth = ArraySize(output_dims, 0);
+ const int output_width = ArraySize(output_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+
+ int buffer_id = 0;
+ // Loop over the output nodes.
+ for (int b = 0; b < batches; ++b) {
+ for (int h = 0; h < output_height; ++h) {
+ for (int w = 0; w < output_width; ++w) {
+ ExtractPatchIntoBufferColumn(
+ input_dims, w, h, b, kheight, kwidth, stride_width, stride_height,
+ pad_width, pad_height, input_width, input_height, input_depth,
+ output_depth, buffer_id, input_data, output_data, byte_zero);
+ ++buffer_id;
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <typename T>
+void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int kheight, int kwidth,
+ uint8 byte_zero, T* output_data, const Dims<4>& output_dims) {
+ Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
+ kwidth, byte_zero, output_data, output_dims);
+}
+
+inline void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ (void)im2col_data;
+ (void)im2col_dims;
+ gemmlowp::ScopedProfilingLabel label("Conv");
+
+ const float* gemm_input_data = nullptr;
+ const Dims<4>* gemm_input_dims = nullptr;
+ const int filter_width = ArraySize(filter_dims, 1);
+ const int filter_height = ArraySize(filter_dims, 2);
+ const bool need_im2col = stride_width != 1 || stride_height != 1 ||
+ filter_width != 1 || filter_height != 1;
+ if (need_im2col) {
+ TFLITE_DCHECK(im2col_data);
+ Im2col(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_height, filter_width, 0, im2col_data,
+ im2col_dims);
+ gemm_input_data = im2col_data;
+ gemm_input_dims = &im2col_dims;
+ } else {
+ // TODO(aselle): We need to make sure to not send im2col if it is not
+ // needed.
+ TFLITE_DCHECK(!im2col_data);
+ gemm_input_data = input_data;
+ gemm_input_dims = &input_dims;
+ }
+
+ const auto im2col_matrix_map =
+ MapAsMatrixWithFirstDimAsRows(gemm_input_data, *gemm_input_dims);
+ const auto filter_matrix_map =
+ MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
+ auto output_matrix_map =
+ MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+
+ Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
+
+ AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data,
+ output_dims, output_activation_min,
+ output_activation_max);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
+ stride_width, stride_height, pad_width, pad_height,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride, stride, pad_width, pad_height, output_data,
+ output_dims, im2col_data, im2col_dims);
+}
+
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ gemmlowp::ScopedProfilingLabel label("Conv/8bit");
+
+ TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+
+ const uint8* gemm_input_data = nullptr;
+ const Dims<4>* gemm_input_dims = nullptr;
+ const int filter_width = ArraySize(filter_dims, 1);
+ const int filter_height = ArraySize(filter_dims, 2);
+ const bool need_im2col = stride_width != 1 || stride_height != 1 ||
+ filter_width != 1 || filter_height != 1;
+ if (need_im2col) {
+ TFLITE_DCHECK(im2col_data);
+ const int input_zero_point = -input_offset;
+ TFLITE_DCHECK_GE(input_zero_point, 0);
+ TFLITE_DCHECK_LE(input_zero_point, 255);
+ Im2col(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_height, filter_width, input_zero_point,
+ im2col_data, im2col_dims);
+ gemm_input_data = im2col_data;
+ gemm_input_dims = &im2col_dims;
+ } else {
+ TFLITE_DCHECK(!im2col_data);
+ gemm_input_data = input_data;
+ gemm_input_dims = &input_dims;
+ }
+
+ const int gemm_input_rows = gemm_input_dims->sizes[0];
+ const int gemm_input_cols = gemm_input_dims->sizes[1] *
+ gemm_input_dims->sizes[2] *
+ gemm_input_dims->sizes[3];
+ const int filter_rows = filter_dims.sizes[3];
+ const int filter_cols =
+ filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2];
+ const int output_rows = output_dims.sizes[0];
+ const int output_cols =
+ output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3];
+ TFLITE_DCHECK_EQ(output_rows, filter_rows);
+ TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
+ TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
+ gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
+ filter_data, filter_rows, filter_cols);
+ gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
+ gemm_input_data, gemm_input_rows, gemm_input_cols);
+ gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
+ output_data, output_rows, output_cols);
+ const auto& output_pipeline = GemmlowpOutputPipeline::Make(
+ bias_data, output_rows, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max);
+ gemmlowp::GemmWithOutputPipeline<uint8, uint8,
+ gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
+ gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset,
+ input_offset, output_pipeline);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width, stride_height,
+ pad_width, pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims, gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride, stride, pad_width,
+ pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims, gemm_context);
+}
+
+template <typename T>
+inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("DepthToSpace");
+
+ const int input_depth = ArraySize(input_dims, 0);
+ const int input_width = ArraySize(input_dims, 1);
+ const int input_height = ArraySize(input_dims, 2);
+
+ const int output_depth = ArraySize(output_dims, 0);
+ const int batch_size = ArraySize(output_dims, 3);
+
+ // Number of continuous values that we can copy in one interation.
+ const int stride = block_size * output_depth;
+
+ for (int batch = 0; batch < batch_size; ++batch) {
+ for (int in_h = 0; in_h < input_height; ++in_h) {
+ const T* input_ptr = input_data + Offset(input_dims, 0, 0, in_h, batch);
+ for (int offset_h = 0; offset_h < block_size; ++offset_h) {
+ const T* src = input_ptr;
+ for (int in_w = 0; in_w < input_width; ++in_w) {
+ memcpy(output_data, src, stride * sizeof(T));
+ output_data += stride;
+ src += input_depth;
+ }
+ input_ptr += stride;
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac, typename T>
+void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int kheight, int kwidth,
+ uint8 byte_zero, T* output_data, const Dims<4>& output_dims) {
+ Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
+ kwidth, byte_zero, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void ConvAsGemm(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("ConvAsGemm");
+
+ const auto input_matrix_map =
+ MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
+ const auto filter_matrix_map =
+ MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
+ auto output_matrix_map =
+ MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+
+ Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);
+
+ AddBiasAndEvalActivationFunction<Ac>(bias_data, bias_dims, output_data,
+ output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ gemmlowp::ScopedProfilingLabel label("ConvAsGemm/8bit");
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ const int input_rows = input_dims.sizes[0];
+ const int input_cols =
+ input_dims.sizes[1] * input_dims.sizes[2] * input_dims.sizes[3];
+ const int filter_rows = filter_dims.sizes[3];
+ const int filter_cols =
+ filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2];
+ const int output_rows = output_dims.sizes[0];
+ const int output_cols =
+ output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3];
+ TFLITE_DCHECK_EQ(output_rows, filter_rows);
+ TFLITE_DCHECK_EQ(output_cols, input_cols);
+ TFLITE_DCHECK_EQ(filter_cols, input_rows);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
+ gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
+ filter_data, output_rows, filter_cols, filter_cols);
+ gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
+ input_data, filter_cols, output_cols, filter_cols);
+ gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
+ output_data, output_rows, output_cols, output_rows);
+ const auto& output_pipeline = GemmlowpOutputPipeline::Make(
+ bias_data, output_rows, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max);
+ gemmlowp::GemmWithOutputPipeline<uint8, uint8,
+ gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
+ gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset,
+ input_offset, output_pipeline);
+}
+
+template <typename T>
+inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("SpaceToDepth");
+
+ const int output_depth = ArraySize(output_dims, 0);
+ const int output_width = ArraySize(output_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+
+ const int input_depth = ArraySize(input_dims, 0);
+ const int batch_size = ArraySize(input_dims, 3);
+
+ // Number of continuous values that we can copy in one interation.
+ const int stride = block_size * input_depth;
+
+ for (int batch = 0; batch < batch_size; ++batch) {
+ for (int out_h = 0; out_h < output_height; ++out_h) {
+ T* output_ptr = output_data + Offset(output_dims, 0, 0, out_h, batch);
+ for (int offset_h = 0; offset_h < block_size; ++offset_h) {
+ T* dst = output_ptr;
+ for (int out_w = 0; out_w < output_width; ++out_w) {
+ memcpy(dst, input_data, stride * sizeof(T));
+ input_data += stride;
+ dst += output_depth;
+ }
+ output_ptr += stride;
+ }
+ }
+ }
+}
+
+template <FusedActivationFunctionType Ac>
+void NonGlobalBatchNormalization(
+ const float* input_data, const Dims<4>& input_dims, const float* mean_data,
+ const Dims<4>& mean_dims, const float* multiplier_data,
+ const Dims<4>& multiplier_dims, const float* offset_data,
+ const Dims<4>& offset_dims, float* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("NonGlobalBatchNormalization");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height =
+ MatchingArraySize(input_dims, 2, mean_dims, 2, multiplier_dims, 2,
+ offset_dims, 2, output_dims, 2);
+ const int width =
+ MatchingArraySize(input_dims, 1, mean_dims, 1, multiplier_dims, 1,
+ offset_dims, 1, output_dims, 1);
+ const int depth =
+ MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0,
+ offset_dims, 0, output_dims, 0);
+
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction<Ac>(
+ (input_data[Offset(input_dims, c, x, y, b)] -
+ mean_data[Offset(mean_dims, c, x, y, 0)]) *
+ multiplier_data[Offset(multiplier_dims, c, x, y, 0)] +
+ offset_data[Offset(offset_dims, c, x, y, 0)]);
+ }
+ }
+ }
+ }
+}
+
+template <FusedActivationFunctionType Ac>
+void GlobalBatchNormalization(const float* input_data,
+ const Dims<4>& input_dims, const float* mean_data,
+ const Dims<4>& mean_dims,
+ const float* multiplier_data,
+ const Dims<4>& multiplier_dims,
+ const float* offset_data,
+ const Dims<4>& offset_dims, float* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("GlobalBatchNormalization");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth =
+ MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0,
+ offset_dims, 0, output_dims, 0);
+
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction<Ac>(
+ (input_data[Offset(input_dims, c, x, y, b)] -
+ mean_data[Offset(mean_dims, c, 0, 0, 0)]) *
+ multiplier_data[Offset(multiplier_dims, c, 0, 0, 0)] +
+ offset_data[Offset(offset_dims, c, 0, 0, 0)]);
+ }
+ }
+ }
+ }
+}
+
+inline void Relu(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Relu (not fused)");
+
+ const auto input = MapAsVector(input_data, input_dims);
+ auto output = MapAsVector(output_data, output_dims);
+ output = input.cwiseMax(0.0f);
+}
+
+inline void Relu1(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ float val = input_data[Offset(input_dims, c, x, y, b)];
+ const float upper = 1;
+ const float lower = -1;
+ float clamped = val > upper ? upper : val < lower ? lower : val;
+ output_data[Offset(output_dims, c, x, y, b)] = clamped;
+ }
+ }
+ }
+ }
+}
+
+inline void Relu6(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ float val = input_data[Offset(input_dims, c, x, y, b)];
+ const float upper = 6;
+ const float lower = 0;
+ float clamped = val > upper ? upper : val < lower ? lower : val;
+ output_data[Offset(output_dims, c, x, y, b)] = clamped;
+ }
+ }
+ }
+ }
+}
+
+template <FusedActivationFunctionType Ac>
+void L2Normalization(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("L2Normalization");
+ static_assert(Ac == FusedActivationFunctionType::kNone, "");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ float squared_l2_norm = 0;
+ for (int c = 0; c < depth; ++c) {
+ float val = input_data[Offset(input_dims, c, x, y, b)];
+ squared_l2_norm += val * val;
+ }
+ float inverse_l2_norm = 1.0f / std::sqrt(squared_l2_norm);
+ for (int c = 0; c < depth; ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ input_data[Offset(input_dims, c, x, y, b)] * inverse_l2_norm;
+ }
+ }
+ }
+ }
+}
+
+inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt,
+ int* output_shift) {
+ *output_shift = 11;
+ while (input >= (1 << 29)) {
+ input /= 4;
+ ++*output_shift;
+ }
+ TFLITE_DCHECK_GT(input, 0);
+ const unsigned max_left_shift_bits = __builtin_clz(input) - 1;
+ const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
+ const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
+ *output_shift -= left_shift_bit_pairs;
+ input <<= 2 * left_shift_bit_pairs;
+ TFLITE_DCHECK_GE(input, (1 << 27));
+ TFLITE_DCHECK_LT(input, (1 << 29));
+ using gemmlowp::FixedPoint;
+ using gemmlowp::Rescale;
+ using gemmlowp::SaturatingRoundingMultiplyByPOT;
+ // Using 3 integer bits gives us enough room for the internal arithmetic in
+ // this Newton-Raphson iteration.
+ using F3 = FixedPoint<int32, 3>;
+ using F0 = FixedPoint<int32, 0>;
+ const F3 fixedpoint_input = F3::FromRaw(input >> 1);
+ const F3 fixedpoint_half_input =
+ SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
+ const F3 fixedpoint_half_three =
+ GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
+ // Newton-Raphson iteration
+ // Naive unoptimized starting guess: x = 1
+ F3 x = F3::One();
+ // Naive unoptimized number of iterations: 5
+ for (int i = 0; i < 5; i++) {
+ const F3 x3 = Rescale<3>(x * x * x);
+ x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
+ }
+ const F0 fixedpoint_half_sqrt_2 =
+ GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.);
+ x = x * fixedpoint_half_sqrt_2;
+ *output_inv_sqrt = x.raw();
+ if (*output_shift < 0) {
+ *output_inv_sqrt <<= -*output_shift;
+ *output_shift = 0;
+ }
+}
+
+inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, uint8* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("L2Normalization/8bit");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ TFLITE_DCHECK_EQ(batches, 1);
+ TFLITE_DCHECK_EQ(height, 1);
+ TFLITE_DCHECK_EQ(width, 1);
+ int32 square_l2_norm = 0;
+ for (int i = 0; i < depth; i++) {
+ int32 diff = input_data[i] - input_zero_point;
+ square_l2_norm += diff * diff;
+ }
+ int32 inv_l2norm_multiplier;
+ int inv_l2norm_shift;
+ GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier,
+ &inv_l2norm_shift);
+
+ for (int i = 0; i < depth; i++) {
+ int32 diff = input_data[i] - input_zero_point;
+ int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne(
+ 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
+ int32 unclamped_output_val = 128 + rescaled_diff;
+ int32 output_val = std::min(255, std::max(0, unclamped_output_val));
+ output_data[i] = static_cast<uint8>(output_val);
+ }
+}
+
+inline void Add(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Add");
+ /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3,
+ output_dims, 3);
+ /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2,
+ output_dims, 2);
+ /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1,
+ output_dims, 1);
+ /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0,
+ output_dims, 0);
+ TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+
+ int i = 0;
+ const int size = input1_dims.sizes[3] * input1_dims.strides[3];
+#ifdef USE_NEON
+ const auto activation_min = vdupq_n_f32(output_activation_min);
+ const auto activation_max = vdupq_n_f32(output_activation_max);
+ for (; i <= size - 16; i += 16) {
+ auto a10 = vld1q_f32(input1_data + i);
+ auto a11 = vld1q_f32(input1_data + i + 4);
+ auto a12 = vld1q_f32(input1_data + i + 8);
+ auto a13 = vld1q_f32(input1_data + i + 12);
+ auto a20 = vld1q_f32(input2_data + i);
+ auto a21 = vld1q_f32(input2_data + i + 4);
+ auto a22 = vld1q_f32(input2_data + i + 8);
+ auto a23 = vld1q_f32(input2_data + i + 12);
+ auto x0 = vaddq_f32(a10, a20);
+ auto x1 = vaddq_f32(a11, a21);
+ auto x2 = vaddq_f32(a12, a22);
+ auto x3 = vaddq_f32(a13, a23);
+ x0 = vmaxq_f32(activation_min, x0);
+ x1 = vmaxq_f32(activation_min, x1);
+ x2 = vmaxq_f32(activation_min, x2);
+ x3 = vmaxq_f32(activation_min, x3);
+ x0 = vminq_f32(activation_max, x0);
+ x1 = vminq_f32(activation_max, x1);
+ x2 = vminq_f32(activation_max, x2);
+ x3 = vminq_f32(activation_max, x3);
+ vst1q_f32(output_data + i, x0);
+ vst1q_f32(output_data + i + 4, x1);
+ vst1q_f32(output_data + i + 8, x2);
+ vst1q_f32(output_data + i + 12, x3);
+ }
+ for (; i <= size - 4; i += 4) {
+ auto a1 = vld1q_f32(input1_data + i);
+ auto a2 = vld1q_f32(input2_data + i);
+ auto x = vaddq_f32(a1, a2);
+ x = vmaxq_f32(activation_min, x);
+ x = vminq_f32(activation_max, x);
+ vst1q_f32(output_data + i, x);
+ }
+#endif // NEON
+
+ for (; i < size; i++) {
+ auto x = input1_data[i] + input2_data[i];
+ output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min,
+ output_activation_max);
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Add(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void Add(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier, int input2_shift,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ gemmlowp::ScopedProfilingLabel label("Add/8bit");
+ /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3,
+ output_dims, 3);
+ /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2,
+ output_dims, 2);
+ /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1,
+ output_dims, 1);
+ /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0,
+ output_dims, 0);
+ TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+
+ int i = 0;
+ const int size = input1_dims.sizes[3] * input1_dims.strides[3];
+ TFLITE_DCHECK_GT(input1_offset, -256);
+ TFLITE_DCHECK_GT(input2_offset, -256);
+ TFLITE_DCHECK_LT(input1_offset, 256);
+ TFLITE_DCHECK_LT(input2_offset, 256);
+#ifdef USE_NEON
+ for (; i <= size - 8; i += 8) {
+ const auto input1_val_original = vld1_u8(input1_data + i);
+ const auto input2_val_original = vld1_u8(input2_data + i);
+ const auto input1_val_s16 =
+ vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
+ const auto input2_val_s16 =
+ vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
+ const auto input1_val =
+ vaddq_s16(input1_val_s16, vdupq_n_s16(input1_offset));
+ const auto input2_val =
+ vaddq_s16(input2_val_s16, vdupq_n_s16(input2_offset));
+ const auto input1_val_high = vget_high_s16(input1_val);
+ const auto input1_val_low = vget_low_s16(input1_val);
+ const auto input2_val_high = vget_high_s16(input2_val);
+ const auto input2_val_low = vget_low_s16(input2_val);
+ auto x11 = vmovl_s16(input1_val_low);
+ auto x12 = vmovl_s16(input1_val_high);
+ auto x21 = vmovl_s16(input2_val_low);
+ auto x22 = vmovl_s16(input2_val_high);
+ const auto left_shift_dup = vdupq_n_s32(left_shift);
+ x11 = vshlq_s32(x11, left_shift_dup);
+ x12 = vshlq_s32(x12, left_shift_dup);
+ x21 = vshlq_s32(x21, left_shift_dup);
+ x22 = vshlq_s32(x22, left_shift_dup);
+ x11 = vqrdmulhq_n_s32(x11, input1_multiplier);
+ x12 = vqrdmulhq_n_s32(x12, input1_multiplier);
+ x21 = vqrdmulhq_n_s32(x21, input2_multiplier);
+ x22 = vqrdmulhq_n_s32(x22, input2_multiplier);
+ const auto input1_shift_dup = vdupq_n_s32(-input1_shift);
+ const auto input2_shift_dup = vdupq_n_s32(-input2_shift);
+ x11 = vshlq_s32(x11, input1_shift_dup);
+ x12 = vshlq_s32(x12, input1_shift_dup);
+ x21 = vshlq_s32(x21, input2_shift_dup);
+ x22 = vshlq_s32(x22, input2_shift_dup);
+ auto s1 = vaddq_s32(x11, x21);
+ auto s2 = vaddq_s32(x12, x22);
+ s1 = vqrdmulhq_n_s32(s1, output_multiplier);
+ s2 = vqrdmulhq_n_s32(s2, output_multiplier);
+ using gemmlowp::RoundingDivideByPOT;
+ s1 = RoundingDivideByPOT(s1, output_shift);
+ s2 = RoundingDivideByPOT(s2, output_shift);
+ const auto s1_narrowed = vmovn_s32(s1);
+ const auto s2_narrowed = vmovn_s32(s2);
+ const auto s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
+ vdupq_n_s16(output_offset));
+ vst1_u8(output_data + i, vqmovun_s16(s));
+ }
+#endif // NEON
+
+ for (; i < size; i++) {
+ const int32 input1_val = input1_offset + input1_data[i];
+ const int32 input2_val = input2_offset + input2_data[i];
+ const int32 shifted_input1_val = input1_val * (1 << left_shift);
+ const int32 shifted_input2_val = input2_val * (1 << left_shift);
+ const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_input1_val, input1_multiplier, input1_shift);
+ const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_input2_val, input2_multiplier, input2_shift);
+ const int32 raw_sum = scaled_input1_val + scaled_input2_val;
+ const int32 raw_output = MultiplyByQuantizedMultiplierSmallerThanOne(
+ raw_sum, output_multiplier, output_shift) +
+ output_offset;
+ const int32 clamped_output = std::min(
+ output_activation_max, std::max(output_activation_min, raw_output));
+ output_data[i] = static_cast<uint8>(clamped_output);
+ }
+}
+
+template <FusedActivationFunctionType Ac>
+void Add(const int32* input1_data, const Dims<4>& input1_dims,
+ const int32* input2_data, const Dims<4>& input2_dims,
+ int32* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Add/int32");
+ TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+
+ auto input1_map = MapAsVector(input1_data, input1_dims);
+ auto input2_map = MapAsVector(input2_data, input2_dims);
+ auto output_map = MapAsVector(output_data, output_dims);
+ if (AreSameDims(input1_dims, input2_dims)) {
+ output_map.array() = input1_map.array() + input2_map.array();
+ } else if (RequiredBufferSizeForDims(input2_dims) == 1) {
+ auto scalar = input2_data[0];
+ output_map.array() = input1_map.array() + scalar;
+ } else if (RequiredBufferSizeForDims(input1_dims) == 1) {
+ auto scalar = input1_data[0];
+ output_map.array() = scalar + input2_map.array();
+ } else {
+ // Should not come here.
+ TFLITE_DCHECK(false);
+ }
+}
+
+// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary
+// dimensionality if the runtime code does a single loop over one dimension
+// that handles broadcasting as the base case. The code generator would then
+// generate max(D1, D2) nested for loops.
+// TODO(benoitjacob): BroadcastAdd is intentionally duplicated from
+// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
+// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
+// reference_ops.h.
+template <FusedActivationFunctionType Ac, typename T>
+void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastAdd");
+
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction<Ac>(
+ input1_data[SubscriptToIndex(desc1, c, x, y, b)] +
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ }
+ }
+ }
+ }
+}
+
+inline void BroadcastAdd(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastAdd/8bit");
+
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ const int32 input1_val =
+ input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
+ const int32 input2_val =
+ input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
+ const int32 shifted_input1_val = input1_val * (1 << left_shift);
+ const int32 shifted_input2_val = input2_val * (1 << left_shift);
+ const int32 scaled_input1_val =
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_input1_val, input1_multiplier, input1_shift);
+ const int32 scaled_input2_val =
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_input2_val, input2_multiplier, input2_shift);
+ const int32 raw_sum = scaled_input1_val + scaled_input2_val;
+ const int32 raw_output =
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ raw_sum, output_multiplier, output_shift) +
+ output_offset;
+ const int32 clamped_output =
+ std::min(output_activation_max,
+ std::max(output_activation_min, raw_output));
+ output_data[Offset(output_dims, c, x, y, b)] =
+ static_cast<uint8>(clamped_output);
+ }
+ }
+ }
+ }
+}
+
+template <FusedActivationFunctionType Ac>
+inline void BroadcastAdd(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ BroadcastAdd(left_shift, input1_data, input1_dims, input1_offset,
+ input1_multiplier, input1_shift, input2_data, input2_dims,
+ input2_offset, input2_multiplier, input2_shift, output_offset,
+ output_multiplier, output_shift, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Mul");
+ /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3,
+ output_dims, 3);
+ /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2,
+ output_dims, 2);
+ /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1,
+ output_dims, 1);
+ /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0,
+ output_dims, 0);
+ TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+
+ int i = 0;
+ const int size = input1_dims.sizes[3] * input1_dims.strides[3];
+#ifdef USE_NEON
+ const auto activation_min = vdupq_n_f32(output_activation_min);
+ const auto activation_max = vdupq_n_f32(output_activation_max);
+ for (; i <= size - 16; i += 16) {
+ auto a10 = vld1q_f32(input1_data + i);
+ auto a11 = vld1q_f32(input1_data + i + 4);
+ auto a12 = vld1q_f32(input1_data + i + 8);
+ auto a13 = vld1q_f32(input1_data + i + 12);
+ auto a20 = vld1q_f32(input2_data + i);
+ auto a21 = vld1q_f32(input2_data + i + 4);
+ auto a22 = vld1q_f32(input2_data + i + 8);
+ auto a23 = vld1q_f32(input2_data + i + 12);
+ auto x0 = vmulq_f32(a10, a20);
+ auto x1 = vmulq_f32(a11, a21);
+ auto x2 = vmulq_f32(a12, a22);
+ auto x3 = vmulq_f32(a13, a23);
+
+ x0 = vmaxq_f32(activation_min, x0);
+ x1 = vmaxq_f32(activation_min, x1);
+ x2 = vmaxq_f32(activation_min, x2);
+ x3 = vmaxq_f32(activation_min, x3);
+ x0 = vminq_f32(activation_max, x0);
+ x1 = vminq_f32(activation_max, x1);
+ x2 = vminq_f32(activation_max, x2);
+ x3 = vminq_f32(activation_max, x3);
+
+ vst1q_f32(output_data + i, x0);
+ vst1q_f32(output_data + i + 4, x1);
+ vst1q_f32(output_data + i + 8, x2);
+ vst1q_f32(output_data + i + 12, x3);
+ }
+ for (; i <= size - 4; i += 4) {
+ auto a1 = vld1q_f32(input1_data + i);
+ auto a2 = vld1q_f32(input2_data + i);
+ auto x = vmulq_f32(a1, a2);
+
+ x = vmaxq_f32(activation_min, x);
+ x = vminq_f32(activation_max, x);
+
+ vst1q_f32(output_data + i, x);
+ }
+#endif // NEON
+
+ for (; i < size; i++) {
+ auto x = input1_data[i] * input2_data[i];
+ output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min,
+ output_activation_max);
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Mul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+template <FusedActivationFunctionType Ac>
+void Mul(const int32* input1_data, const Dims<4>& input1_dims,
+ const int32* input2_data, const Dims<4>& input2_dims,
+ int32* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Mul/int32");
+ TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+
+ auto input1_map = MapAsVector(input1_data, input1_dims);
+ auto input2_map = MapAsVector(input2_data, input2_dims);
+ auto output_map = MapAsVector(output_data, output_dims);
+ if (AreSameDims(input1_dims, input2_dims)) {
+ output_map.array() = input1_map.array() * input2_map.array();
+ } else if (RequiredBufferSizeForDims(input2_dims) == 1) {
+ auto scalar = input2_data[0];
+ output_map.array() = input1_map.array() * scalar;
+ } else if (RequiredBufferSizeForDims(input1_dims) == 1) {
+ auto scalar = input1_data[0];
+ output_map.array() = scalar * input2_map.array();
+ } else {
+ // Should not come here.
+ TFLITE_DCHECK(false);
+ }
+}
+
+// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary
+// dimensionality if the runtime code does a single loop over one dimension
+// that handles broadcasting as the base case. The code generator would then
+// generate max(D1, D2) nested for loops.
+// TODO(benoitjacob): BroadcastMul is intentionally duplicated from
+// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
+// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
+// reference_ops.h.
+template <FusedActivationFunctionType Ac, typename T>
+void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastMul");
+
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction<Ac>(
+ input1_data[SubscriptToIndex(desc1, c, x, y, b)] *
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ }
+ }
+ }
+ }
+}
+
+inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
+ int32 input1_offset, const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit");
+
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ const int32 input1_val =
+ input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
+ const int32 input2_val =
+ input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
+ const int32 unclamped_result =
+ output_offset +
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ input1_val * input2_val, output_multiplier, output_shift);
+ const int32 clamped_output =
+ std::min(output_activation_max,
+ std::max(output_activation_min, unclamped_result));
+ output_data[Offset(output_dims, c, x, y, b)] =
+ static_cast<uint8>(clamped_output);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
+ int32 input1_offset, const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ BroadcastMul(input1_data, input1_dims, input1_offset, input2_data,
+ input2_dims, input2_offset, output_offset, output_multiplier,
+ output_shift, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+template <FusedActivationFunctionType Ac, typename Scalar>
+void Concatenation(int concat_dim, const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, int inputs_count,
+ Scalar* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Concatenation");
+ int concat_size = 0;
+ for (int i = 0; i < inputs_count; i++) {
+ for (int j = 0; j < 4; j++) {
+ if (j != concat_dim) {
+ MatchingArraySize(*input_dims[i], j, output_dims, j);
+ }
+ }
+ concat_size += ArraySize(*input_dims[i], concat_dim);
+ }
+ TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim));
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ // for now we dont have a model with a Concatenation
+ // with fused activation function.
+ TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+ int outer_size = 1;
+ for (int i = concat_dim + 1; i < 4; i++) {
+ outer_size *= output_dims.sizes[i];
+ }
+ Scalar* output_ptr = output_data;
+ for (int k = 0; k < outer_size; k++) {
+ for (int i = 0; i < inputs_count; ++i) {
+ const int copy_size =
+ input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim];
+ memcpy(output_ptr, input_data[i] + k * copy_size,
+ copy_size * sizeof(Scalar));
+ output_ptr += copy_size;
+ }
+ }
+}
+
+template <FusedActivationFunctionType Ac, typename Scalar>
+void DepthConcatenation(const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, int inputs_count,
+ Scalar* output_data, const Dims<4>& output_dims) {
+ Concatenation<Ac, Scalar>(0, input_data, input_dims, inputs_count,
+ output_data, output_dims);
+}
+
+inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
+ const float* prev_activ_data,
+ const Dims<4>& prev_activ_dims, const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims, const float* prev_state_data,
+ const Dims<4>& prev_state_dims, float* output_state_data,
+ const Dims<4>& output_state_dims, float* output_activ_data,
+ const Dims<4>& output_activ_dims, float* concat_temp_data,
+ const Dims<4>& concat_temp_dims, float* activ_temp_data,
+ const Dims<4>& activ_temp_dims) {
+ gemmlowp::ScopedProfilingLabel label("LstmCell");
+ MatchingArraySize( // batches
+ input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3, output_state_dims,
+ 3, output_activ_dims, 3);
+ MatchingArraySize( // height
+ input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2, output_state_dims,
+ 2, output_activ_dims, 2);
+ MatchingArraySize( // width
+ input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1, output_state_dims,
+ 1, output_activ_dims, 1);
+ TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
+ TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
+ const int input_depth = ArraySize(input_dims, 0);
+ const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
+ const int total_input_depth = prev_activ_depth + input_depth;
+ TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
+ TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
+ 1);
+ const int intern_activ_depth =
+ MatchingArraySize(weights_dims, 1, bias_dims, 0);
+ TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
+ const int output_depth =
+ MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
+ output_state_dims, 0, output_activ_dims, 0);
+ TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
+
+ // Concatenate prev_activ and input data together
+ std::vector<float const*> concat_input_arrays_data;
+ std::vector<Dims<4> const*> concat_input_arrays_dims;
+ concat_input_arrays_data.push_back(input_data);
+ concat_input_arrays_data.push_back(prev_activ_data);
+ concat_input_arrays_dims.push_back(&input_dims);
+ concat_input_arrays_dims.push_back(&prev_activ_dims);
+ Concatenation<FusedActivationFunctionType::kNone, float>(
+ 0, &(concat_input_arrays_data[0]), &(concat_input_arrays_dims[0]),
+ concat_input_arrays_data.size(), concat_temp_data, concat_temp_dims);
+
+ // Fully connected
+ FullyConnected<FusedActivationFunctionType::kNone>(
+ concat_temp_data, concat_temp_dims, weights_data, weights_dims, bias_data,
+ bias_dims, activ_temp_data, activ_temp_dims);
+
+ // Map raw arrays to Eigen arrays so we can use Eigen's optimized array
+ // operations.
+ ArrayMap<float> activ_temp_map =
+ MapAsArrayWithFirstDimAsRows(activ_temp_data, activ_temp_dims);
+ auto input_gate_sm = activ_temp_map.block(0 * output_depth, 0, output_depth,
+ activ_temp_map.cols());
+ auto new_input_sm = activ_temp_map.block(1 * output_depth, 0, output_depth,
+ activ_temp_map.cols());
+ auto forget_gate_sm = activ_temp_map.block(2 * output_depth, 0, output_depth,
+ activ_temp_map.cols());
+ auto output_gate_sm = activ_temp_map.block(3 * output_depth, 0, output_depth,
+ activ_temp_map.cols());
+ ArrayMap<const float> prev_state_map =
+ MapAsArrayWithFirstDimAsRows(prev_state_data, prev_state_dims);
+ ArrayMap<float> output_state_map =
+ MapAsArrayWithFirstDimAsRows(output_state_data, output_state_dims);
+ ArrayMap<float> output_activ_map =
+ MapAsArrayWithFirstDimAsRows(output_activ_data, output_activ_dims);
+
+ // Combined memory state and final output calculation
+ gemmlowp::ScopedProfilingLabel label2("MemoryStateAndFinalOutput");
+ output_state_map =
+ input_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op<float>()) *
+ new_input_sm.tanh() +
+ forget_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op<float>()) *
+ prev_state_map;
+ output_activ_map =
+ output_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op<float>()) *
+ output_state_map.tanh();
+}
+
+template <FusedActivationFunctionType Ac, typename Scalar>
+void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
+ int outputs_count, Scalar* const* output_data,
+ const Dims<4>* const* output_dims) {
+ gemmlowp::ScopedProfilingLabel label("TensorFlowSplit");
+ TFLITE_DCHECK_GE(outputs_count, 1);
+ for (int i = 0; i < outputs_count; i++) {
+ /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3);
+ /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2);
+ /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1);
+ }
+ const int batches = MatchingArraySize(*output_dims[0], 3, input_dims, 3);
+ const int height = MatchingArraySize(*output_dims[0], 2, input_dims, 2);
+ const int width = MatchingArraySize(*output_dims[0], 1, input_dims, 1);
+ TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
+ // for now we dont have a model with a TensorFlowSplit
+ // with fused activation function.
+ TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+ const int whb = width * height * batches;
+ const Scalar* input_ptr = input_data;
+ for (int k = 0; k < whb; k++) {
+ for (int i = 0; i < outputs_count; ++i) {
+ memcpy(output_data[i] + k * output_dims[i]->sizes[0], input_ptr,
+ output_dims[i]->sizes[0] * sizeof(Scalar));
+ input_ptr += output_dims[i]->sizes[0];
+ }
+ }
+}
+
+inline int NodeOffset(int b, int h, int w, int height, int width) {
+ return (b * height + h) * width + w;
+}
+
+inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int kwidth, int kheight,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("AveragePool");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+
+ // TODO(benoitjacob) make this a proper reference impl without Eigen!
+ const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
+ auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ // TODO(benoitjacob) get rid of the dynamic memory allocation here!
+ Eigen::VectorXf out_count(out_mat.cols());
+ out_count.setZero();
+ // Prefill the output to 0.
+ out_mat.setZero();
+ for (int b = 0; b < batches; ++b) {
+ for (int h = 0; h < input_height; ++h) {
+ for (int w = 0; w < input_width; ++w) {
+ // (h_start, h_end) * (w_start, w_end) is the range that the input
+ // vector projects to.
+ int hpad = h + pad_height;
+ int wpad = w + pad_width;
+ int h_start =
+ (hpad < kheight) ? 0 : (hpad - kheight) / stride_height + 1;
+ int h_end = std::min(hpad / stride_height + 1, output_height);
+ int w_start = (wpad < kwidth) ? 0 : (wpad - kwidth) / stride_width + 1;
+ int w_end = std::min(wpad / stride_width + 1, output_width);
+ // compute elementwise sum
+ for (int ph = h_start; ph < h_end; ++ph) {
+ for (int pw = w_start; pw < w_end; ++pw) {
+ int out_offset = NodeOffset(b, ph, pw, output_height, output_width);
+ out_mat.col(out_offset) +=
+ in_mat.col(NodeOffset(b, h, w, input_height, input_width));
+ out_count(out_offset)++;
+ }
+ }
+ }
+ }
+ }
+ // Divide the output by the actual number of elements being averaged over
+ TFLITE_DCHECK_GT(out_count.minCoeff(), 0);
+ out_mat.array().rowwise() /= out_count.transpose().array();
+
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < output_height; ++y) {
+ for (int x = 0; x < output_width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ ActivationFunctionWithMinMax(
+ output_data[Offset(output_dims, c, x, y, b)],
+ output_activation_min, output_activation_max);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int kwidth, int kheight, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, kwidth, kheight, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width,
+ int filter_height, float* output_data,
+ const Dims<4>& output_dims) {
+ AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("AveragePool/8bit");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ for (int batch = 0; batch < batches; ++batch) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int filter_x_start = std::max(0, -in_x_origin);
+ const int filter_x_end =
+ std::min(filter_width, input_width - in_x_origin);
+ const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_end =
+ std::min(filter_height, input_height - in_y_origin);
+ const int filter_count =
+ (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start);
+ // 1280 required by Inception v3
+ static constexpr int kAccBufferMaxSize = 2048;
+ TFLITE_DCHECK_LE(depth, kAccBufferMaxSize);
+ uint16 acc[kAccBufferMaxSize];
+ memset(acc, 0, depth * sizeof(acc[0]));
+ const uint8* input_ptr =
+ input_data + input_dims.strides[1] * in_x_origin +
+ input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch;
+ for (int fy = filter_y_start; fy < filter_y_end; fy++) {
+ const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] +
+ filter_x_start * input_dims.strides[1];
+ for (int fx = filter_x_start; fx < filter_x_end; fx++) {
+ int channel = 0;
+#ifdef USE_NEON
+ for (; channel <= depth - 16; channel += 16) {
+ uint16x8_t acc_reg[2];
+ for (int i = 0; i < 2; i++) {
+ acc_reg[i] = vld1q_u16(acc + channel + 8 * i);
+ }
+ uint8x16_t input_reg = vld1q_u8(input_row_ptr);
+ input_row_ptr += 16;
+ acc_reg[0] = vaddw_u8(acc_reg[0], vget_low_u8(input_reg));
+ acc_reg[1] = vaddw_u8(acc_reg[1], vget_high_u8(input_reg));
+ for (int i = 0; i < 2; i++) {
+ vst1q_u16(acc + channel + 8 * i, acc_reg[i]);
+ }
+ }
+ for (; channel <= depth - 8; channel += 8) {
+ uint16x8_t acc_reg = vld1q_u16(acc + channel);
+ uint8x8_t input_reg = vld1_u8(input_row_ptr);
+ input_row_ptr += 8;
+ acc_reg = vaddw_u8(acc_reg, input_reg);
+ vst1q_u16(acc + channel, acc_reg);
+ }
+#endif
+ for (; channel < depth; ++channel) {
+ acc[channel] += *input_row_ptr++;
+ }
+ }
+ }
+ uint8* output_ptr =
+ output_data + Offset(output_dims, 0, out_x, out_y, batch);
+ int channel = 0;
+#ifdef USE_NEON
+#define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \
+ if (filter_count == FILTER_COUNT) { \
+ for (; channel <= depth - 8; channel += 8) { \
+ uint16 buf[8]; \
+ for (int i = 0; i < 8; i++) { \
+ buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT; \
+ } \
+ uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); \
+ buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max)); \
+ buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min)); \
+ vst1_u8(output_ptr + channel, buf8); \
+ } \
+ }
+ AVGPOOL_DIVIDING_BY(9)
+ AVGPOOL_DIVIDING_BY(15)
+#undef AVGPOOL_DIVIDING_BY
+ for (; channel <= depth - 8; channel += 8) {
+ uint16 buf[8];
+ for (int i = 0; i < 8; i++) {
+ buf[i] = (acc[channel + i] + filter_count / 2) / filter_count;
+ }
+ uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf));
+ buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max));
+ buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min));
+ vst1_u8(output_ptr + channel, buf8);
+ }
+#endif
+ for (; channel < depth; ++channel) {
+ uint16 a = (acc[channel] + filter_count / 2) / filter_count;
+ a = std::max<uint16>(a, output_activation_min);
+ a = std::min<uint16>(a, output_activation_max);
+ output_ptr[channel] = static_cast<uint8>(a);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width,
+ int filter_height, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int kwidth, int kheight,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("MaxPool");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+
+ const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
+ auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ // Prefill the output to minimum representable float value
+ out_mat.setConstant(std::numeric_limits<float>::lowest());
+ for (int b = 0; b < batches; ++b) {
+ for (int h = 0; h < input_height; ++h) {
+ for (int w = 0; w < input_width; ++w) {
+ // (h_start, h_end) * (w_start, w_end) is the range that the input
+ // vector projects to.
+ int hpad = h + pad_height;
+ int wpad = w + pad_width;
+ int h_start =
+ (hpad < kheight) ? 0 : (hpad - kheight) / stride_height + 1;
+ int h_end = std::min(hpad / stride_height + 1, output_height);
+ int w_start = (wpad < kwidth) ? 0 : (wpad - kwidth) / stride_width + 1;
+ int w_end = std::min(wpad / stride_width + 1, output_width);
+ // compute elementwise sum
+ for (int ph = h_start; ph < h_end; ++ph) {
+ for (int pw = w_start; pw < w_end; ++pw) {
+ int out_offset = NodeOffset(b, ph, pw, output_height, output_width);
+ out_mat.col(out_offset) =
+ out_mat.col(out_offset)
+ .cwiseMax(in_mat.col(
+ NodeOffset(b, h, w, input_height, input_width)));
+ }
+ }
+ }
+ }
+ }
+
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < output_height; ++y) {
+ for (int x = 0; x < output_width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ ActivationFunctionWithMinMax(
+ output_data[Offset(output_dims, c, x, y, b)],
+ output_activation_min, output_activation_max);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int kwidth, int kheight, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, kwidth, kheight, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ float* output_data, const Dims<4>& output_dims) {
+ MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("MaxPool/8bit");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ for (int batch = 0; batch < batches; ++batch) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int filter_x_start = std::max(0, -in_x_origin);
+ const int filter_x_end =
+ std::min(filter_width, input_width - in_x_origin);
+ const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_end =
+ std::min(filter_height, input_height - in_y_origin);
+ // 2048 required by Inception v3
+ static constexpr int kAccBufferMaxSize = 2048;
+ TFLITE_DCHECK_LE(depth, kAccBufferMaxSize);
+ uint8 acc[kAccBufferMaxSize];
+ memset(acc, 0, depth * sizeof(acc[0]));
+ const uint8* input_ptr =
+ input_data + input_dims.strides[1] * in_x_origin +
+ input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch;
+ for (int fy = filter_y_start; fy < filter_y_end; fy++) {
+ const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] +
+ filter_x_start * input_dims.strides[1];
+ for (int fx = filter_x_start; fx < filter_x_end; fx++) {
+ int channel = 0;
+#ifdef USE_NEON
+ for (; channel <= depth - 16; channel += 16) {
+ uint8x16_t acc_reg = vld1q_u8(acc + channel);
+ uint8x16_t input_reg = vld1q_u8(input_row_ptr);
+ input_row_ptr += 16;
+ acc_reg = vmaxq_u8(acc_reg, input_reg);
+ vst1q_u8(acc + channel, acc_reg);
+ }
+
+ for (; channel <= depth - 8; channel += 8) {
+ uint8x8_t acc_reg = vld1_u8(acc + channel);
+ uint8x8_t input_reg = vld1_u8(input_row_ptr);
+ input_row_ptr += 8;
+ acc_reg = vmax_u8(acc_reg, input_reg);
+ vst1_u8(acc + channel, acc_reg);
+ }
+#endif
+ for (; channel < depth; ++channel) {
+ acc[channel] = std::max(acc[channel], *input_row_ptr++);
+ }
+ }
+ }
+ uint8* output_ptr =
+ output_data + Offset(output_dims, 0, out_x, out_y, batch);
+ int channel = 0;
+#ifdef USE_NEON
+ for (; channel <= depth - 16; channel += 16) {
+ uint8x16_t a = vld1q_u8(acc + channel);
+ a = vminq_u8(a, vdupq_n_u8(output_activation_max));
+ a = vmaxq_u8(a, vdupq_n_u8(output_activation_min));
+ vst1q_u8(output_ptr + channel, a);
+ }
+ for (; channel <= depth - 8; channel += 8) {
+ uint8x8_t a = vld1_u8(acc + channel);
+ a = vmin_u8(a, vdup_n_u8(output_activation_max));
+ a = vmax_u8(a, vdup_n_u8(output_activation_min));
+ vst1_u8(output_ptr + channel, a);
+ }
+#endif
+ for (; channel < depth; ++channel) {
+ uint8 a = acc[channel];
+ a = std::max<uint8>(a, output_activation_min);
+ a = std::min<uint8>(a, output_activation_max);
+ output_ptr[channel] = static_cast<uint8>(a);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("L2Pool");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ // Actually carry out L2 Pool. Code is written in forward mode: we go through
+ // the input values once, and write to all the pooled regions that it maps to.
+ const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
+ auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ Eigen::VectorXf in_square(in_mat.rows());
+ Eigen::VectorXf out_count(out_mat.cols());
+ out_count.setZero();
+ // Prefill the output to 0.
+ out_mat.setZero();
+ for (int b = 0; b < batches; ++b) {
+ for (int h = 0; h < input_height; ++h) {
+ for (int w = 0; w < input_width; ++w) {
+ // (h_start, h_end) * (w_start, w_end) is the range that the input
+ // vector projects to.
+ const int hpad = h + pad_height;
+ const int wpad = w + pad_width;
+ const int h_start = (hpad < filter_height)
+ ? 0
+ : (hpad - filter_height) / stride_height + 1;
+ const int h_end = std::min(hpad / stride_height + 1, output_height);
+ const int w_start = (wpad < filter_width)
+ ? 0
+ : (wpad - filter_width) / stride_width + 1;
+ const int w_end = std::min(wpad / stride_width + 1, output_width);
+ // pre-compute square
+ const int in_offset = w + input_width * (h + input_height * b);
+ in_square =
+ in_mat.col(in_offset).array() * in_mat.col(in_offset).array();
+ // compute elementwise sum of squares
+ for (int ph = h_start; ph < h_end; ++ph) {
+ for (int pw = w_start; pw < w_end; ++pw) {
+ const int out_offset = pw + output_width * (ph + output_height * b);
+ out_mat.col(out_offset) += in_square;
+ out_count(out_offset)++;
+ }
+ }
+ }
+ }
+ }
+
+ out_count = out_count.array().inverse();
+ out_mat =
+ (out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt();
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void L2Pool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ float* output_data, const Dims<4>& output_dims) {
+ L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void LocalResponseNormalization(const float* input_data,
+ const Dims<4>& input_dims, int range,
+ float bias, float alpha, float beta,
+ float* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("LocalResponseNormalization");
+ /* const int batches = */ MatchingArraySize(input_dims, 3, output_dims, 3);
+ /* const int height = */ MatchingArraySize(input_dims, 2, output_dims, 2);
+ /* const int width = */ MatchingArraySize(input_dims, 1, output_dims, 1);
+ /* const int depth = */ MatchingArraySize(input_dims, 0, output_dims, 0);
+
+ const auto data_in = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
+ auto data_out = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+
+ // Carry out local response normalization, vector by vector.
+ // Since the data are stored column major, making row-wise operation
+ // probably not memory efficient anyway, we do an explicit for loop over
+ // the columns.
+ const int double_range = range * 2;
+ Eigen::VectorXf padded_square(data_in.rows() + double_range);
+ padded_square.setZero();
+ for (int r = 0; r < data_in.cols(); ++r) {
+ // Do local response normalization for data_in(:, r)
+ // first, compute the square and store them in buffer for repeated use
+ padded_square.block(range, 0, data_in.rows(), 1) =
+ data_in.col(r).cwiseProduct(data_in.col(r)) * alpha;
+ // Then, compute the scale and writes them to data_out
+ float accumulated_scale = 0;
+ for (int i = 0; i < double_range; ++i) {
+ accumulated_scale += padded_square(i);
+ }
+ for (int i = 0; i < data_in.rows(); ++i) {
+ accumulated_scale += padded_square(i + double_range);
+ data_out(i, r) = bias + accumulated_scale;
+ accumulated_scale -= padded_square(i);
+ }
+ }
+
+ // In a few cases, the pow computation could benefit from speedups.
+ if (beta == 1) {
+ data_out.array() = data_in.array() * data_out.array().inverse();
+ } else if (beta == 0.5) {
+ data_out.array() = data_in.array() * data_out.array().sqrt().inverse();
+ } else {
+ data_out.array() = data_in.array() * data_out.array().pow(-beta);
+ }
+}
+
+inline void Softmax(const float* input_data, const Dims<4>& input_dims,
+ float beta, float* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Softmax");
+ /* const int batches = */ MatchingArraySize(input_dims, 3, output_dims, 3);
+ /* const int height = */ MatchingArraySize(input_dims, 2, output_dims, 2);
+ /* const int width = */ MatchingArraySize(input_dims, 1, output_dims, 1);
+ /* const int depth = */ MatchingArraySize(input_dims, 0, output_dims, 0);
+
+ const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
+ auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ // Compute the exponential first, removing the max coefficient for numerical
+ // stability.
+ out_mat = (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * beta;
+ // We are separating out the exp function so that exp can be vectorized.
+ out_mat = out_mat.array().exp();
+ // Normalize to get the activations.
+ Eigen::Array<float, 1, Eigen::Dynamic> scale =
+ out_mat.array().colwise().sum().inverse();
+ out_mat.array().rowwise() *= scale;
+}
+
+inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_beta_multiplier, int32 input_beta_left_shift,
+ int diff_min, uint8* output_data,
+ const Dims<4>& output_dims) {
+ // The representation chosen for the input to the exp() function is Q5.26.
+ // We need to leave extra space since values that we skip might be as large as
+ // -32 before multiplying by input_beta_multiplier, and therefore as large as
+ // -16 afterwards. Note that exp(-8) is definitely not insignificant to
+ // accumulation, but exp(-16) definitely is.
+ static const int kScaledDiffIntegerBits = 5;
+ static const int kAccumulationIntegerBits = 12;
+ using FixedPointScaledDiff =
+ gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
+ using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
+ using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
+
+ gemmlowp::ScopedProfilingLabel label("Softmax");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+
+ for (int b = 0; b < batches; ++b) {
+ for (int x = 0; x < width; ++x) {
+ for (int y = 0; y < height; ++y) {
+ uint8 max_in_row = 0;
+ for (int c = 0; c < depth; ++c) {
+ max_in_row =
+ std::max(max_in_row, input_data[Offset(input_dims, c, x, y, b)]);
+ }
+
+ FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
+ for (int c = 0; c < depth; ++c) {
+ int32 input_diff =
+ static_cast<int32>(input_data[Offset(input_dims, c, x, y, b)]) -
+ max_in_row;
+ if (input_diff >= diff_min) {
+ const int32 input_diff_rescaled =
+ MultiplyByQuantizedMultiplierGreaterThanOne(
+ input_diff, input_beta_multiplier, input_beta_left_shift);
+ const FixedPointScaledDiff scaled_diff_f8 =
+ FixedPointScaledDiff::FromRaw(input_diff_rescaled);
+ sum_of_exps =
+ sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
+ exp_on_negative_values(scaled_diff_f8));
+ }
+ }
+
+ int32 fixed_sum_of_exps = sum_of_exps.raw();
+ // TODO(starka): Use a NEON intrinsic like vclzq_u32 instead.
+ int headroom_plus_one =
+ __builtin_clz(static_cast<uint32>(fixed_sum_of_exps));
+ // This is the number of bits to the left of the binary point above 1.0.
+ // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and
+ // no later adjustment will be needed.
+ int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one;
+ int32 shifted_sum_minus_one = static_cast<int32>(
+ (static_cast<uint32>(fixed_sum_of_exps) << headroom_plus_one) -
+ (static_cast<uint32>(1) << 31));
+
+ FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1(
+ FixedPoint0::FromRaw(shifted_sum_minus_one));
+
+ for (int c = 0; c < depth; ++c) {
+ int32 input_diff =
+ static_cast<int32>(input_data[Offset(input_dims, c, x, y, b)]) -
+ max_in_row;
+ if (input_diff >= diff_min) {
+ const int32 input_diff_rescaled =
+ MultiplyByQuantizedMultiplierGreaterThanOne(
+ input_diff, input_beta_multiplier, input_beta_left_shift);
+ const FixedPointScaledDiff scaled_diff_f8 =
+ FixedPointScaledDiff::FromRaw(input_diff_rescaled);
+
+ FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
+ int32 unsat_output = gemmlowp::RoundingDivideByPOT(
+ (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
+
+ output_data[Offset(output_dims, c, x, y, b)] =
+ std::max(std::min(unsat_output, 255), 0);
+
+ } else {
+ output_data[Offset(output_dims, c, x, y, b)] = 0;
+ }
+ }
+ }
+ }
+ }
+}
+
+inline void Logistic(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Logistic");
+ auto input_map = MapAsVector(input_data, input_dims);
+ auto output_map = MapAsVector(output_data, output_dims);
+ output_map.array() =
+ input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op<float>());
+}
+
+inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Logistic");
+ /* batches */ MatchingArraySize(input_dims, 3, output_dims, 3);
+ /* height */ MatchingArraySize(input_dims, 2, output_dims, 2);
+ /* width */ MatchingArraySize(input_dims, 1, output_dims, 1);
+ /* depth */ MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int size = RequiredBufferSizeForDims(input_dims);
+
+ int c = 0;
+#ifdef USE_NEON
+ // Handle 16 values at a time
+ for (; c <= size - 16; c += 16) {
+ // Read input uint8 values, cast to int16 and subtract input_zero_point
+ uint8x16_t input_val_u8 = vld1q_u8(input_data + c);
+ int16x8_t input_val_centered_0 =
+ vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))),
+ vdupq_n_s16(input_zero_point));
+ int16x8_t input_val_centered_1 =
+ vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))),
+ vdupq_n_s16(input_zero_point));
+
+ // Prepare the bit masks that we will use at the end to implement the logic
+ // that was expressed in the scalar code with branching:
+ // if (input_val_centered < -input_range_radius) {
+ // output_val = 0;
+ // } else if (input_val_centered > input_range_radius) {
+ // output_val = 255;
+ // } else {
+ // ...
+ uint16x8_t mask_rightclamp_0 =
+ vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius));
+ uint16x8_t mask_rightclamp_1 =
+ vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius));
+ uint16x8_t mask_leftclamp_0 =
+ vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius));
+ uint16x8_t mask_leftclamp_1 =
+ vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius));
+ uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
+ vshrn_n_u16(mask_rightclamp_1, 8));
+ uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
+ vshrn_n_u16(mask_leftclamp_1, 8));
+
+ // This performs what is expressed in the scalar code as
+ // const int32 input_val_rescaled =
+ // MultiplyByQuantizedMultiplierGreaterThanOne(
+ // input_val_centered, input_multiplier, input_left_shift);
+ int32x4_t input_val_rescaled_0 =
+ vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)),
+ vdupq_n_s32(input_left_shift));
+ int32x4_t input_val_rescaled_1 =
+ vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)),
+ vdupq_n_s32(input_left_shift));
+ int32x4_t input_val_rescaled_2 =
+ vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)),
+ vdupq_n_s32(input_left_shift));
+ int32x4_t input_val_rescaled_3 =
+ vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)),
+ vdupq_n_s32(input_left_shift));
+ input_val_rescaled_0 =
+ vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier);
+ input_val_rescaled_1 =
+ vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier);
+ input_val_rescaled_2 =
+ vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier);
+ input_val_rescaled_3 =
+ vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier);
+
+ // Invoke gemmlowp::logistic on FixedPoint wrapping int32x4_t
+ using FixedPoint4 = gemmlowp::FixedPoint<int32x4_t, 4>;
+ using FixedPoint0 = gemmlowp::FixedPoint<int32x4_t, 0>;
+ const FixedPoint4 input_val_f4_0 =
+ FixedPoint4::FromRaw(input_val_rescaled_0);
+ const FixedPoint4 input_val_f4_1 =
+ FixedPoint4::FromRaw(input_val_rescaled_1);
+ const FixedPoint4 input_val_f4_2 =
+ FixedPoint4::FromRaw(input_val_rescaled_2);
+ const FixedPoint4 input_val_f4_3 =
+ FixedPoint4::FromRaw(input_val_rescaled_3);
+ const FixedPoint0 output_val_f0_0 = gemmlowp::logistic(input_val_f4_0);
+ const FixedPoint0 output_val_f0_1 = gemmlowp::logistic(input_val_f4_1);
+ const FixedPoint0 output_val_f0_2 = gemmlowp::logistic(input_val_f4_2);
+ const FixedPoint0 output_val_f0_3 = gemmlowp::logistic(input_val_f4_3);
+
+ // Divide by 2^23 as in the scalar code
+ using gemmlowp::RoundingDivideByPOT;
+ int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 23);
+ int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 23);
+ int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 23);
+ int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 23);
+
+ // Cast output values to uint8, saturating
+ int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0),
+ vqmovn_s32(output_val_s32_1));
+ int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2),
+ vqmovn_s32(output_val_s32_3));
+ uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0),
+ vqmovun_s16(output_val_s16_1));
+
+ // Perform the bit-masking with the bit masks computed at the beginning,
+ // see the comment there.
+ output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp);
+ output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp);
+
+ // Store back to memory
+ vst1q_u8(output_data + c, output_val_u8);
+ }
+#endif
+ // Leftover loop: handle one value at a time with scalar code.
+ for (; c < size; ++c) {
+ const uint8 input_val_u8 = input_data[c];
+ const int32 input_val_centered =
+ static_cast<int32>(input_val_u8) - input_zero_point;
+ uint8 output_val;
+ if (input_val_centered < -input_range_radius) {
+ output_val = 0;
+ } else if (input_val_centered > input_range_radius) {
+ output_val = 255;
+ } else {
+ const int32 input_val_rescaled =
+ MultiplyByQuantizedMultiplierGreaterThanOne(
+ input_val_centered, input_multiplier, input_left_shift);
+ using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
+ using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
+ const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
+ const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
+ using gemmlowp::RoundingDivideByPOT;
+ int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23);
+ if (output_val_s32 == 256) {
+ output_val_s32 = 255;
+ }
+ TFLITE_DCHECK_GE(output_val_s32, 0);
+ TFLITE_DCHECK_LE(output_val_s32, 255);
+ output_val = static_cast<uint8>(output_val_s32);
+ }
+ output_data[c] = output_val;
+ }
+}
+
+inline void Tanh(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Tanh");
+ auto input_map = MapAsVector(input_data, input_dims);
+ auto output_map = MapAsVector(output_data, output_dims);
+ output_map.array() = input_map.array().tanh();
+}
+
+inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
+ int32 zero_point, double scale, float* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Dequantize");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ int32 val = input_data[Offset(input_dims, c, x, y, b)];
+ float result = static_cast<float>(scale * (val - zero_point));
+ output_data[Offset(output_dims, c, x, y, b)] = result;
+ }
+ }
+ }
+ }
+}
+
+inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
+ float rmin, float rmax, float* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("FakeQuant");
+
+ // 0 should always be a representable value. Let's assume that the initial
+ // min,max range contains 0.
+ TFLITE_DCHECK_LE(rmin, 0.);
+ TFLITE_DCHECK_GE(rmax, 0.);
+
+ // Determine quantization parameters: zero_point, scale.
+ using Integer = uint8;
+ const Integer qmin = std::numeric_limits<Integer>::min();
+ const Integer qmax = std::numeric_limits<Integer>::max();
+ const float qmin_float = qmin;
+ const float qmax_float = qmax;
+ int32 zero_point = 0;
+ float scale = 0.f;
+ // If rmin==rmax, both must be zero per the above assertion,
+ // so we are done.
+ if (rmin != rmax) {
+ // First determine the scale.
+ scale = (rmax - rmin) / (qmax_float - qmin_float);
+
+ // Zero-point computation.
+ // First the initial floating-point computation. The zero-point can be
+ // determined from solving an affine equation for any known pair
+ // (real value, corresponding quantized value).
+ // We know two such pairs: (rmin, qmin) and (rmax, qmax).
+ // The arithmetic error on the zero point computed from either pair
+ // will be roughly machine_epsilon * (sum of absolute values of terms)
+ // so we want to use the variant that adds the smaller terms.
+ const float zero_point_from_min = qmin_float - rmin / scale;
+ const float zero_point_from_max = qmax_float - rmax / scale;
+ const float zero_point_from_min_error =
+ std::abs(qmin_float) + std::abs(rmin / scale);
+ const float zero_point_from_max_error =
+ std::abs(qmax_float) + std::abs(rmax / scale);
+
+ const float zero_point_float =
+ zero_point_from_min_error < zero_point_from_max_error
+ ? zero_point_from_min
+ : zero_point_from_max;
+
+ // Now we need to nudge the zero point to be an integer
+ // (our zero points are integer, and this is motivated by the requirement
+ // to be able to represent the real value "0" exactly as a quantized value,
+ // which is required in multiple places, for example in Im2col with SAME
+ // padding).
+ if (zero_point_float < qmin_float) {
+ zero_point = qmin;
+ } else if (zero_point_float > qmax_float) {
+ zero_point = qmax;
+ } else {
+ zero_point = static_cast<int32>(TfLiteRound(zero_point_float));
+ }
+ // The zero point should always be in the range of quantized value,
+ // [qmin, qmax].
+ TFLITE_DCHECK_GE(zero_point, qmin);
+ TFLITE_DCHECK_LE(zero_point, qmax);
+ }
+
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ const float src_val = input_data[Offset(input_dims, c, x, y, b)];
+ const float unclamped_quantized_val =
+ TfLiteRound(zero_point + src_val / scale);
+ const float quantized_val = std::min(
+ qmax_float, std::max(qmin_float, unclamped_quantized_val));
+ const float dst_val = scale * (quantized_val - zero_point);
+ output_data[Offset(output_dims, c, x, y, b)] = dst_val;
+ }
+ }
+ }
+ }
+}
+
+template <typename SrcT, typename DstT>
+inline void Cast(const SrcT* input_data, const Dims<4>& input_dims,
+ DstT* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Cast");
+ auto input_map = MapAsVector(input_data, input_dims);
+ auto output_map = MapAsVector(output_data, output_dims);
+ output_map.array() = input_map.array().template cast<DstT>();
+}
+
+inline void Floor(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Floor");
+ auto input_map = MapAsVector(input_data, input_dims);
+ auto output_map = MapAsVector(output_data, output_dims);
+ output_map.array() = Eigen::floor(input_map.array());
+}
+
+template <typename T>
+inline void Gather(const T* input_data, const Dims<4>& input_dims,
+ int input_rank, const int32* coords_data,
+ const Dims<4>& coords_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Gather");
+
+ TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]);
+ int stride = input_dims.strides[input_rank - 1];
+ T* out = output_data;
+
+ for (int i = 0; i < coords_dims.sizes[0]; i++) {
+ TFLITE_DCHECK_GE(coords_data[i], 0);
+ TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]);
+ const T* in = input_data + coords_data[i] * stride;
+ memcpy(out, in, sizeof(T) * stride);
+ out += stride;
+ }
+}
+
+#ifdef USE_NEON
+inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
+ float scale, float* output_ptr) {
+ int ic = 0;
+ // Handle 32 input channels at a time.
+ for (; ic <= depth - 32; ic += 32) {
+ float32x4x2_t input[4];
+ for (int i = 0; i < 4; i++) {
+ input[i].val[0] = vld1q_f32(input_ptr + 8 * i);
+ input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4);
+ }
+ float32x4x2_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i].val[0] = vld1q_f32(output_ptr + 8 * i);
+ acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4);
+ }
+ for (int i = 0; i < 4; i++) {
+ acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale);
+ acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale);
+ }
+ for (int i = 0; i < 4; i++) {
+ vst1q_f32(output_ptr, acc[i].val[0]);
+ vst1q_f32(output_ptr + 4, acc[i].val[1]);
+ output_ptr += 8;
+ }
+ input_ptr += 32;
+ }
+ // Handle 16 input channels at a time.
+ for (; ic <= depth - 16; ic += 16) {
+ float32x4x2_t input[2];
+ for (int i = 0; i < 2; i++) {
+ input[i].val[0] = vld1q_f32(input_ptr + 8 * i);
+ input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4);
+ }
+ float32x4x2_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i].val[0] = vld1q_f32(output_ptr + 8 * i);
+ acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4);
+ }
+ for (int i = 0; i < 2; i++) {
+ acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale);
+ acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale);
+ }
+ for (int i = 0; i < 2; i++) {
+ vst1q_f32(output_ptr, acc[i].val[0]);
+ vst1q_f32(output_ptr + 4, acc[i].val[1]);
+ output_ptr += 8;
+ }
+ input_ptr += 16;
+ }
+ // Handle 8 input channels at a time.
+ for (; ic <= depth - 8; ic += 8) {
+ float32x4x2_t input;
+ input.val[0] = vld1q_f32(input_ptr);
+ input.val[1] = vld1q_f32(input_ptr + 4);
+
+ float32x4x2_t acc;
+ acc.val[0] = vld1q_f32(output_ptr);
+ acc.val[1] = vld1q_f32(output_ptr + 4);
+ acc.val[0] = vmlaq_n_f32(acc.val[0], input.val[0], scale);
+ acc.val[1] = vmlaq_n_f32(acc.val[1], input.val[1], scale);
+
+ vst1q_f32(output_ptr, acc.val[0]);
+ vst1q_f32(output_ptr + 4, acc.val[1]);
+
+ input_ptr += 8;
+ output_ptr += 8;
+ }
+ // Handle 4 input channels at a time.
+ for (; ic <= depth - 4; ic += 4) {
+ float32x4_t input = vld1q_f32(input_ptr);
+ float32x4_t acc = vld1q_f32(output_ptr);
+
+ acc = vmlaq_n_f32(acc, input, scale);
+ vst1q_f32(output_ptr, acc);
+
+ input_ptr += 4;
+ output_ptr += 4;
+ }
+ // Handle 1 input channel at a time.
+ for (; ic < depth; ic++) {
+ *output_ptr += *input_ptr * scale;
+ output_ptr++;
+ input_ptr++;
+ }
+}
+#else
+inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
+ float scale, float* output_ptr) {
+ for (int32 i = 0; i < depth; i++) {
+ *output_ptr += *input_ptr * scale;
+ output_ptr++;
+ input_ptr++;
+ }
+}
+#endif
+
+inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
+ int32 x, int32 y, int32 depth, int32 batch,
+ const float* input_data,
+ const Dims<4>& input_dims,
+ float* output_data,
+ const Dims<4>& output_dims) {
+ const int32 input_width = ArraySize(input_dims, 1);
+ const int32 output_width = ArraySize(output_dims, 1);
+
+ const int32 input_x_offset = (x1 - x0) * depth;
+ const int32 input_y_offset = (y1 - y0) * depth * input_width;
+ const int32 output_x_offset = depth;
+ const int32 output_y_offset = depth * output_width;
+
+#ifdef USE_NEON
+ TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
+ TFLITE_DCHECK(x1 >= x0);
+ TFLITE_DCHECK(y1 >= y0);
+
+ int ic = 0;
+ // Handle 8 input channels at a time.
+ for (; ic <= depth - 8; ic += 8) {
+ const float* input_ptr = nullptr;
+
+ float32x4x2_t x0y0;
+ input_ptr = &input_data[Offset(input_dims, ic, x0, y0, batch)];
+ x0y0.val[0] = vld1q_f32(input_ptr);
+ x0y0.val[1] = vld1q_f32(input_ptr + 4);
+
+ float32x4x2_t x1y0;
+ input_ptr += input_x_offset;
+ x1y0.val[0] = vld1q_f32(input_ptr);
+ x1y0.val[1] = vld1q_f32(input_ptr + 4);
+
+ float32x4x2_t x0y1;
+ input_ptr += -input_x_offset + input_y_offset;
+ x0y1.val[0] = vld1q_f32(input_ptr);
+ x0y1.val[1] = vld1q_f32(input_ptr + 4);
+
+ float32x4x2_t x1y1;
+ input_ptr += input_x_offset;
+ x1y1.val[0] = vld1q_f32(input_ptr);
+ x1y1.val[1] = vld1q_f32(input_ptr + 4);
+
+ // Top left corner.
+ float* output_ptr = &output_data[Offset(output_dims, ic, x, y, batch)];
+ vst1q_f32(output_ptr, x0y0.val[0]);
+ vst1q_f32(output_ptr + 4, x0y0.val[1]);
+
+ // Top right corner.
+ output_ptr += output_x_offset;
+ float32x4x2_t tr;
+ tr.val[0] = vaddq_f32(x0y0.val[0], x1y0.val[0]);
+ tr.val[1] = vaddq_f32(x0y0.val[1], x1y0.val[1]);
+ tr.val[0] = vmulq_n_f32(tr.val[0], 0.5f);
+ tr.val[1] = vmulq_n_f32(tr.val[1], 0.5f);
+
+ vst1q_f32(output_ptr, tr.val[0]);
+ vst1q_f32(output_ptr + 4, tr.val[1]);
+
+ // Bottom left corner.
+ output_ptr += -output_x_offset + output_y_offset;
+ float32x4x2_t bl;
+ bl.val[0] = vaddq_f32(x0y0.val[0], x0y1.val[0]);
+ bl.val[1] = vaddq_f32(x0y0.val[1], x0y1.val[1]);
+ bl.val[0] = vmulq_n_f32(bl.val[0], 0.5f);
+ bl.val[1] = vmulq_n_f32(bl.val[1], 0.5f);
+ vst1q_f32(output_ptr, bl.val[0]);
+ vst1q_f32(output_ptr + 4, bl.val[1]);
+
+ // Bottom right corner.
+ output_ptr += output_x_offset;
+ float32x4x2_t br;
+ br.val[0] = vaddq_f32(x1y0.val[0], x1y1.val[0]);
+ br.val[1] = vaddq_f32(x1y0.val[1], x1y1.val[1]);
+ br.val[0] = vmlaq_n_f32(bl.val[0], br.val[0], 0.5f);
+ br.val[1] = vmlaq_n_f32(bl.val[1], br.val[1], 0.5f);
+ br.val[0] = vmulq_n_f32(br.val[0], 0.5f);
+ br.val[1] = vmulq_n_f32(br.val[1], 0.5f);
+ vst1q_f32(output_ptr, br.val[0]);
+ vst1q_f32(output_ptr + 4, br.val[1]);
+ }
+ // Handle 4 input channels at a time.
+ for (; ic <= depth - 4; ic += 4) {
+ const float* input_ptr = &input_data[Offset(input_dims, ic, x0, y0, batch)];
+ float32x4_t x0y0 = vld1q_f32(input_ptr);
+ float32x4_t x1y0 = vld1q_f32(input_ptr + input_x_offset);
+ float32x4_t x0y1 = vld1q_f32(input_ptr + input_y_offset);
+ float32x4_t x1y1 = vld1q_f32(input_ptr + input_x_offset + input_y_offset);
+
+ // Top left corner.
+ float* output_ptr = &output_data[Offset(output_dims, ic, x, y, batch)];
+ vst1q_f32(output_ptr, x0y0);
+
+ // Top right corner.
+ output_ptr += output_x_offset;
+ float32x4_t tr = vaddq_f32(x0y0, x1y0);
+ tr = vmulq_n_f32(tr, 0.5f);
+ vst1q_f32(output_ptr, tr);
+
+ // Bottom left corner.
+ output_ptr += -output_x_offset + output_y_offset;
+ float32x4_t bl = vaddq_f32(x0y0, x0y1);
+ bl = vmulq_n_f32(bl, 0.5f);
+ vst1q_f32(output_ptr, bl);
+
+ // Bottom right corner.
+ output_ptr += output_x_offset;
+ float32x4_t br = vaddq_f32(x1y0, x1y1);
+ br = vmlaq_n_f32(bl, br, 0.5f);
+ br = vmulq_n_f32(br, 0.5f);
+ vst1q_f32(output_ptr, br);
+ }
+ // Handle one input channel at a time.
+ for (; ic < depth; ic++) {
+ const int32 input_offset = Offset(input_dims, ic, x0, y0, batch);
+
+ float x0y0 = input_data[input_offset];
+ float x1y0 = input_data[input_offset + input_x_offset];
+ float x0y1 = input_data[input_offset + input_y_offset];
+ float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
+
+ // Top left corner.
+ const int32 output_offset = Offset(output_dims, ic, x, y, batch);
+ output_data[output_offset] = x0y0;
+
+ // Top right corner.
+ output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2;
+
+ // Bottom left corner.
+ float output = (x0y0 + x0y1) / 2;
+ output_data[output_offset + output_y_offset] = output;
+
+ // Bottom right corner.
+ output_data[output_offset + output_x_offset + output_y_offset] =
+ (output + ((x1y0 + x1y1) / 2)) / 2;
+ }
+#else
+ for (int ch = 0; ch < depth; ch++) {
+ const int32 input_offset = Offset(input_dims, ch, x0, y0, batch);
+
+ float x0y0 = input_data[input_offset];
+ float x1y0 = input_data[input_offset + input_x_offset];
+ float x0y1 = input_data[input_offset + input_y_offset];
+ float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
+
+ // Top left corner.
+ const int32 output_offset = Offset(output_dims, ch, x, y, batch);
+ output_data[output_offset] = x0y0;
+
+ // Top right corner.
+ output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2;
+
+ // Bottom left corner.
+ float output = (x0y0 + x0y1) / 2;
+ output_data[output_offset + output_y_offset] = output;
+
+ // Bottom right corner.
+ output_data[output_offset + output_x_offset + output_y_offset] =
+ (output + ((x1y0 + x1y1) / 2)) / 2;
+ }
+#endif
+}
+
+inline void ResizeBilinear2x2(const float* input_data,
+ const Dims<4>& input_dims, float* output_data,
+ const Dims<4>& output_dims, int32 batches,
+ int32 input_height, int32 input_width,
+ int32 depth, int32 output_height,
+ int32 output_width) {
+ for (int b = 0; b < batches; b++) {
+ for (int y0 = 0, y = 0; y <= output_height - 2; y += 2, y0++) {
+ for (int x0 = 0, x = 0; x <= output_width - 2; x += 2, x0++) {
+ int32 x1 = std::min(x0 + 1, input_width - 1);
+ int32 y1 = std::min(y0 + 1, input_height - 1);
+ ResizeBilinearKernel2x2(x0, x1, y0, y1, x, y, depth, b, input_data,
+ input_dims, output_data, output_dims);
+ }
+ }
+ }
+}
+
+inline void ResizeBilinearGeneric(const float* input_data,
+ const Dims<4>& input_dims, float* output_data,
+ const Dims<4>& output_dims, int32 batches,
+ int32 input_height, int32 input_width,
+ int32 depth, int32 output_height,
+ int32 output_width, float height_scale,
+ float width_scale) {
+ memset(output_data, 0,
+ batches * output_height * output_width * depth * sizeof(float));
+
+ int32 output_offset = 0;
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < output_height; ++y) {
+ float input_y = y * height_scale;
+ int32 y0 = static_cast<int32>(std::floor(input_y));
+ int32 y1 = std::min(y0 + 1, input_height - 1);
+ for (int x = 0; x < output_width; ++x) {
+ float input_x = x * width_scale;
+ int32 x0 = static_cast<int32>(input_x);
+ int32 x1 = std::min(x0 + 1, input_width - 1);
+ float* output_ptr = &output_data[output_offset];
+
+ // Run kernel on the 4 corners of the bilinear resize algorithm.
+ int32 input_offset = Offset(input_dims, 0, x0, y0, b);
+ float scale = (1 - (input_y - y0)) * (1 - (input_x - x0));
+ const float* input_ptr = &input_data[input_offset];
+ ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
+
+ input_offset = Offset(input_dims, 0, x1, y0, b);
+ scale = (1 - (input_y - y0)) * (input_x - x0);
+ input_ptr = &input_data[input_offset];
+ ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
+
+ input_offset = Offset(input_dims, 0, x0, y1, b);
+ scale = (input_y - y0) * (1 - (input_x - x0));
+ input_ptr = &input_data[input_offset];
+ ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
+
+ input_offset = Offset(input_dims, 0, x1, y1, b);
+ scale = (input_y - y0) * (input_x - x0);
+ input_ptr = &input_data[input_offset];
+ ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
+
+ output_offset += depth;
+ }
+ }
+ }
+}
+
+inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, float* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
+ int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ int32 input_height = ArraySize(input_dims, 2);
+ int32 input_width = ArraySize(input_dims, 1);
+ int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+
+ TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1);
+ TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1);
+ TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1);
+ TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2);
+ int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)];
+ int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)];
+
+ // Specialize for 2x2 upsample.
+ if (output_height == 2 * input_height && output_width == 2 * input_width) {
+ ResizeBilinear2x2(input_data, input_dims, output_data, output_dims, batches,
+ input_height, input_width, depth, output_height,
+ output_width);
+ } else {
+ float height_scale = static_cast<float>(input_height) / output_height;
+ float width_scale = static_cast<float>(input_width) / output_width;
+
+ ResizeBilinearGeneric(input_data, input_dims, output_data, output_dims,
+ batches, input_height, input_width, depth,
+ output_height, output_width, height_scale,
+ width_scale);
+ }
+}
+
+template <typename T>
+inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims,
+ const int32* paddings_data,
+ const Dims<4>& paddings_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("SpaceToBatchND");
+
+ const int output_batch_size = ArraySize(output_dims, 3);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ const int input_batch_size = ArraySize(input_dims, 3);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int depth = ArraySize(input_dims, 0);
+ const int block_shape_height = block_shape_data[0];
+ const int block_shape_width = block_shape_data[1];
+ const int padding_top = paddings_data[0];
+ const int padding_left = paddings_data[2];
+
+ for (int out_b = 0; out_b < output_batch_size; ++out_b) {
+ int input_batch = out_b % input_batch_size;
+ int shift_w = (out_b / input_batch_size) % block_shape_width;
+ int shift_h = (out_b / input_batch_size) / block_shape_width;
+ for (int out_h = 0; out_h < output_height; ++out_h) {
+ for (int out_w = 0; out_w < output_width; ++out_w) {
+ T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b);
+ if (out_h * block_shape_height < padding_top ||
+ out_h * block_shape_height >= padding_top + input_height ||
+ out_w * block_shape_width < padding_left ||
+ out_w * block_shape_width >= padding_left + input_width) {
+ memset(out, 0, depth * sizeof(T));
+ } else {
+ const T* in =
+ input_data +
+ Offset(input_dims, 0,
+ (out_w * block_shape_width + shift_w) - padding_left,
+ (out_h * block_shape_height + shift_h) - padding_top,
+ input_batch);
+ memcpy(out, in, depth * sizeof(T));
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("BatchToSpaceND");
+
+ const int output_batch_size = ArraySize(output_dims, 3);
+ const int input_batch_size = ArraySize(input_dims, 3);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int depth = ArraySize(input_dims, 0);
+ const int block_shape_width = block_shape_data[1];
+ const int block_shape_height = block_shape_data[0];
+
+ for (int in_batch = 0; in_batch < input_batch_size; ++in_batch) {
+ for (int in_h = 0; in_h < input_height; ++in_h) {
+ for (int in_w = 0; in_w < input_width; ++in_w) {
+ int out_batch = in_batch % output_batch_size;
+ int out_w = in_w * block_shape_width +
+ (in_batch / output_batch_size) % block_shape_width;
+ int out_h = in_h * block_shape_height +
+ (in_batch / output_batch_size) / block_shape_width;
+ T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch);
+ const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch);
+ memcpy(out, in, depth * sizeof(T));
+ }
+ }
+ }
+}
+
+template <typename T>
+inline void Pad(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Pad");
+ const int output_batch = ArraySize(output_dims, 3);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ const int output_depth = ArraySize(output_dims, 0);
+
+ const int left_b_padding = left_paddings[3];
+ const int left_h_padding = left_paddings[2];
+ const int left_w_padding = left_paddings[1];
+ const int left_d_padding = left_paddings[0];
+
+ const int right_b_padding = right_paddings[3];
+ const int right_h_padding = right_paddings[2];
+ const int right_w_padding = right_paddings[1];
+ const int right_d_padding = right_paddings[0];
+
+ const int input_depth = ArraySize(input_dims, 0);
+
+ if (left_b_padding != 0) {
+ memset(output_data, 0,
+ left_b_padding * output_height * output_width * output_depth *
+ sizeof(T));
+ }
+ for (int out_b = left_b_padding; out_b < output_batch - right_b_padding;
+ ++out_b) {
+ if (left_h_padding != 0) {
+ memset(output_data + Offset(output_dims, 0, 0, 0, out_b), 0,
+ left_h_padding * output_width * output_depth * sizeof(T));
+ }
+ for (int out_h = left_h_padding; out_h < output_height - right_h_padding;
+ ++out_h) {
+ if (left_w_padding != 0) {
+ memset(output_data + Offset(output_dims, 0, 0, out_h, out_b), 0,
+ left_w_padding * output_depth * sizeof(T));
+ }
+ for (int out_w = left_w_padding; out_w < output_width - right_w_padding;
+ ++out_w) {
+ if (left_d_padding != 0) {
+ memset(output_data + Offset(output_dims, 0, out_w, out_h, out_b), 0,
+ left_d_padding * sizeof(T));
+ }
+
+ T* out = output_data +
+ Offset(output_dims, left_d_padding, out_w, out_h, out_b);
+ const T* in =
+ input_data + Offset(input_dims, 0, out_w - left_w_padding,
+ out_h - left_h_padding, out_b - left_b_padding);
+ memcpy(out, in, input_depth * sizeof(T));
+
+ if (right_d_padding != 0) {
+ memset(
+ output_data + Offset(output_dims, output_depth - right_d_padding,
+ out_w, out_h, out_b),
+ 0, right_d_padding * sizeof(T));
+ }
+ }
+ if (right_w_padding != 0) {
+ memset(
+ output_data + Offset(output_dims, 0, output_width - right_w_padding,
+ out_h, out_b),
+ 0, right_w_padding * output_depth * sizeof(T));
+ }
+ }
+ if (right_h_padding != 0) {
+ memset(output_data + Offset(output_dims, 0, 0,
+ output_height - right_h_padding, out_b),
+ 0, right_h_padding * output_width * output_depth * sizeof(T));
+ }
+ }
+ if (right_b_padding != 0) {
+ memset(output_data +
+ Offset(output_dims, 0, 0, 0, output_batch - right_b_padding),
+ 0,
+ right_b_padding * output_height * output_width * output_depth *
+ sizeof(T));
+ }
+}
+
+template <typename T>
+inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
+ int begin_mask, int end_mask,
+ const std::vector<int>& starts,
+ const std::vector<int>& stops,
+ const std::vector<int>& strides, T* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("StridedSlice");
+ const int start_b = (begin_mask & 8) ? 0 : starts[3];
+ const int stop_b = (end_mask & 8) ? input_dims.sizes[3] : stops[3];
+ const int start_h = (begin_mask & 4) ? 0 : starts[2];
+ const int stop_h = (end_mask & 4) ? input_dims.sizes[2] : stops[2];
+ const int start_w = (begin_mask & 2) ? 0 : starts[1];
+ const int stop_w = (end_mask & 2) ? input_dims.sizes[1] : stops[1];
+ const int start_d = (begin_mask & 1) ? 0 : starts[0];
+ const int stop_d = (end_mask & 1) ? input_dims.sizes[0] : stops[0];
+
+ T* out_ptr = output_data;
+ if (strides[0] == 0) {
+ for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) {
+ for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) {
+ for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) {
+ const int len = stop_d - start_d;
+ memcpy(out_ptr,
+ input_data + Offset(input_dims, start_d, in_w, in_h, in_b),
+ len * sizeof(T));
+ out_ptr += len;
+ }
+ }
+ }
+ } else {
+ for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) {
+ for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) {
+ for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) {
+ for (int in_d = start_d; in_d < stop_d; in_d += strides[0]) {
+ *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
+ }
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+inline void Slice(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& begin, const std::vector<int>& size,
+ T* output_data, const Dims<4>& output_dims) {
+ // TODO(dkalenichenko): This op only supports 4D tensors.
+ TFLITE_DCHECK_EQ(begin.size(), 4);
+ TFLITE_DCHECK_EQ(size.size(), 4);
+ const int start_b = begin[3];
+ const int stop_b =
+ size[3] == -1 ? input_dims.sizes[3] - start_b : start_b + size[3];
+ const int start_h = begin[2];
+ const int stop_h =
+ size[2] == -1 ? input_dims.sizes[2] - start_b : start_b + size[2];
+ const int start_w = begin[1];
+ const int stop_w =
+ size[1] == -1 ? input_dims.sizes[1] - start_b : start_b + size[1];
+ const int start_d = begin[0];
+ const int stop_d =
+ size[0] == -1 ? input_dims.sizes[0] - start_d : start_d + size[0];
+
+ T* out_ptr = output_data;
+ for (int in_b = start_b; in_b < stop_b; ++in_b) {
+ for (int in_h = start_h; in_h < stop_h; ++in_h) {
+ for (int in_w = start_w; in_w < stop_w; ++in_w) {
+ const int len = stop_d - start_d;
+ memcpy(out_ptr,
+ input_data + Offset(input_dims, start_d, in_w, in_h, in_b),
+ len * sizeof(T));
+ out_ptr += len;
+ }
+ }
+ }
+}
+
+template <typename T>
+inline void Mean(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& reduction_indices, T* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Mean");
+ const int output_batch = ArraySize(output_dims, 3);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ const int output_depth = ArraySize(output_dims, 0);
+
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+
+ // The current implementation only supports simultaneous reduction over
+ // width and height.
+ TFLITE_DCHECK_EQ(reduction_indices.size(), 2);
+ TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) ||
+ (reduction_indices[0] == 2 && reduction_indices[1] == 1));
+ TFLITE_DCHECK_EQ(output_height, 1);
+ TFLITE_DCHECK_EQ(output_width, 1);
+
+ for (int out_b = 0; out_b < output_batch; ++out_b) {
+ for (int out_d = 0; out_d < output_depth; ++out_d) {
+ float value = 0;
+ for (int in_h = 0; in_h < input_height; ++in_h) {
+ for (int in_w = 0; in_w < input_width; ++in_w) {
+ value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)];
+ }
+ }
+ output_data[Offset(output_dims, out_d, 0, 0, out_b)] =
+ value / (input_width * input_height);
+ }
+ }
+}
+
+template <typename T>
+void GenericBroadcastSub(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("GenericBroadcastSub");
+
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ input1_data[SubscriptToIndex(desc1, c, x, y, b)] -
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)];
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
+ const Dims<4>& input2_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Sub");
+
+ auto input1_map = MapAsVector(input1_data, input1_dims);
+ auto input2_map = MapAsVector(input2_data, input2_dims);
+ auto output_map = MapAsVector(output_data, output_dims);
+ if (AreSameDims(input1_dims, input2_dims)) {
+ output_map.array() = input1_map.array() - input2_map.array();
+ } else if (RequiredBufferSizeForDims(input1_dims) == 1) {
+ auto scalar = input1_data[0];
+ output_map.array() = scalar - input2_map.array();
+ } else if (RequiredBufferSizeForDims(input2_dims) == 1) {
+ auto scalar = input2_data[0];
+ output_map.array() = input1_map.array() - scalar;
+ } else {
+ GenericBroadcastSub(input1_data, input1_dims, input2_data, input2_dims,
+ output_data, output_dims);
+ }
+}
+
+template <typename T>
+void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("TensorFlowMinimum");
+ auto input1_map = MapAsVector(input1_data, input1_dims);
+ auto output_map = MapAsVector(output_data, output_dims);
+ auto min_value = input2_data[0];
+ output_map.array() = input1_map.array().min(min_value);
+}
+
+template <typename T>
+void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("TensorFlowMaximum");
+ auto input1_map = MapAsVector(input1_data, input1_dims);
+ auto output_map = MapAsVector(output_data, output_dims);
+ auto max_value = input2_data[0];
+ output_map.array() = input1_map.array().max(max_value);
+}
+} // namespace optimized_ops
+} // namespace tflite
+
+#if defined OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS
+#undef OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS
+#pragma GCC diagnostic pop
+#endif
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
new file mode 100644
index 0000000000..f8be99e82f
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
@@ -0,0 +1,138 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_
+#define TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_
+
+// TDOD(ghodrat): Remove this header file and the dependency to internal data
+// structure.
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+
+#ifndef USE_NEON
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+#define USE_NEON
+#endif // defined(__ARM_NEON__) || defined(__ARM_NEON)
+#endif // USE_NEON
+
+namespace tflite {
+namespace tensor_utils {
+
+// Multiply a matrix by a batch vector, and store results in a batch-size
+// vector.
+void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix,
+ int m_rows, int m_cols,
+ const float* vector,
+ int n_batch, float* result,
+ int result_stride);
+void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
+ int m_cols, const float* vector,
+ int n_batch, float* result,
+ int result_stride);
+
+// Cwise product of two vectors.
+void PortableVectorVectorCwiseProduct(const float* vector1,
+ const float* vector2, int v_size,
+ float* result);
+void NeonVectorVectorCwiseProduct(const float* vector1, const float* vector2,
+ int v_size, float* result);
+
+// Cwise product and accumulate of two vectors. Since it's a MAC operation, the
+// assumption here is that result array is initialized to valid values.
+void PortableVectorVectorCwiseProductAccumulate(const float* vector1,
+ const float* vector2,
+ int v_size, float* result);
+void NeonVectorVectorCwiseProductAccumulate(const float* vector1,
+ const float* vector2, int v_size,
+ float* result);
+
+// Dot product of two vectors.
+float PortableVectorVectorDotProduct(const float* vector1, const float* vector2,
+ int v_size);
+float NeonVectorVectorDotProduct(const float* vector1, const float* vector2,
+ int v_size);
+
+// Dot product of two batch vectors.
+void PortableBatchVectorBatchVectorDotProduct(const float* vector1,
+ const float* vector2, int v_size,
+ int n_batch, float* result,
+ int result_stride);
+void NeonBatchVectorBatchVectorDotProduct(const float* vector1,
+ const float* vector2, int v_size,
+ int n_batch, float* result,
+ int result_stride);
+
+// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC
+// operation, the assumption here is that result array is initialized to valid
+// values.
+void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector,
+ int v_size,
+ const float* batch_vector,
+ int n_batch,
+ float* result);
+void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector,
+ int v_size,
+ const float* batch_vector,
+ int n_batch, float* result);
+
+// Compute "1.0f - elements of vector" (used in CIFG).
+void PortableSub1Vector(const float* vector, int v_size, float* result);
+void NeonSub1Vector(const float* vector, int v_size, float* result);
+
+// Clip elements of a vector using a abs_limit value.
+void PortableClipVector(const float* vector, int v_size, float abs_limit,
+ float* result);
+void NeonClipVector(const float* vector, int v_size, float abs_limit,
+ float* result);
+
+// Batch vector initialization with another vector.
+void PortableVectorBatchVectorAssign(const float* vector, int v_size,
+ int n_batch, float* batch_vector);
+
+// Apply sigmoid to elements of a vector.
+void PortableApplySigmoidToVector(const float* vector, int v_size,
+ float* result);
+
+// Apply activation function to elements of a vector.
+void PortableApplyActivationToVector(const float* vector, int v_size,
+ TfLiteFusedActivation activation,
+ float* result);
+
+// Copy vector to another vector.
+void PortableCopyVector(const float* vector, int v_size, float* result);
+
+// Fill vector with 0.f.
+void PortableZeroVector(float* vector, int v_size);
+
+// Limit a float input f between +abs_limit and -abs_limit.
+float PortableClip(float f, float abs_limit);
+
+// Shift left a vector in place with v_size size.
+void PortableVectorShiftLeft(float* vector, int v_size, float shift_value);
+void NeonVectorShiftLeft(float* vector, int v_size, float shift_value);
+
+// Reduce-sum on a float input vector:
+// input_vector: float pointer to input vector.
+// output_vector: float pointer to vector.
+// output_size: output vector size.
+// reduction_size: number of consecutive elements from input vector which are
+// added to get one element of output.
+void PortableReductionSumVector(const float* input_vector, float* output_vector,
+ int output_size, int reduction_size);
+void NeonReductionSumVector(const float* input_vector, float* output_vector,
+ int output_size, int reduction_size);
+
+} // namespace tensor_utils
+} // namespace tflite
+
+#endif // TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc
new file mode 100644
index 0000000000..98f2e365c5
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc
@@ -0,0 +1,95 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <cmath>
+#include <limits>
+
+#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/round.h"
+
+namespace tflite {
+
+void QuantizeMultiplierSmallerThanOne(double double_multiplier,
+ int32_t* quantized_multiplier,
+ int* right_shift) {
+ TFLITE_CHECK(double_multiplier >= 0.);
+ TFLITE_CHECK(double_multiplier < 1.);
+ if (double_multiplier == 0.) {
+ *quantized_multiplier = 0;
+ *right_shift = 0;
+ return;
+ }
+ TFLITE_CHECK(double_multiplier > 0.);
+ const double q = std::frexp(double_multiplier, right_shift);
+ *right_shift *= -1;
+
+ auto q_fixed = static_cast<int64_t>(TfLiteRound(q * (1ll << 31)));
+ TFLITE_CHECK(q_fixed <= (1ll << 31));
+ if (q_fixed == (1ll << 31)) {
+ q_fixed /= 2;
+ --*right_shift;
+ }
+ TFLITE_CHECK_GE(*right_shift, 0);
+ TFLITE_CHECK_LE(q_fixed, std::numeric_limits<int32_t>::max());
+ *quantized_multiplier = static_cast<int32_t>(q_fixed);
+}
+
+void QuantizeMultiplierGreaterThanOne(double double_multiplier,
+ int32_t* quantized_multiplier,
+ int* left_shift) {
+ TFLITE_CHECK(double_multiplier > 1.);
+ const double q = std::frexp(double_multiplier, left_shift);
+ auto q_fixed = static_cast<int64_t>(TfLiteRound(q * (1ll << 31)));
+ TFLITE_CHECK(q_fixed <= (1ll << 31));
+ if (q_fixed == (1ll << 31)) {
+ q_fixed /= 2;
+ ++*left_shift;
+ }
+ TFLITE_CHECK_GE(*left_shift, 0);
+ TFLITE_CHECK_LE(q_fixed, std::numeric_limits<int32_t>::max());
+ *quantized_multiplier = static_cast<int32_t>(q_fixed);
+}
+
+void PreprocessSoftmaxScaling(double beta, double input_scale,
+ int input_integer_bits,
+ int32_t* quantized_multiplier, int* left_shift) {
+ // If the overall multiplier (input and beta) is large, then exp() of an
+ // input difference of 1 scaled by this will be large. In other words, we
+ // can cap the multiplier and know that, when it is used, the output will be
+ // (round to) zero wherever the input is not at the maximum value.
+
+ // If the overall scale is less than one, and input_integer_bits=0, then the
+ // result is double equivalent of Q0.31 (actually with more precision). Thus
+ // this generates a Q(input_integer_bits).(31-input_integer_bits)
+ // representation.
+ const double input_beta_real_multiplier = std::min(
+ beta * input_scale * (1 << (31 - input_integer_bits)), (1ll << 31) - 1.0);
+
+ QuantizeMultiplierGreaterThanOne(input_beta_real_multiplier,
+ quantized_multiplier, left_shift);
+}
+
+int CalculateInputRadius(int input_integer_bits, int input_left_shift) {
+ const double max_input_rescaled = 1.0 * ((1 << input_integer_bits) - 1) *
+ (1ll << (31 - input_integer_bits)) /
+ (1ll << input_left_shift);
+ // Tighten bound using floor. Suppose that we could use the exact value.
+ // After scaling the difference, the result would be at the maximum. Thus we
+ // must ensure that our value has lower magnitude.
+ return static_cast<int>(std::floor(max_input_rescaled));
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h
new file mode 100644
index 0000000000..efb7191c8d
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h
@@ -0,0 +1,55 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef PHOTOS_VISION_LEARNING_TENSORFLOW_MINI_QUANTIZATION_UTIL_H_
+#define PHOTOS_VISION_LEARNING_TENSORFLOW_MINI_QUANTIZATION_UTIL_H_
+
+#include <cstdint>
+
+namespace tflite {
+
+// Decompose a double multiplier into a Q0.31 int32 representation of its
+// significand, and shift representation of its exponent.
+//
+// Restricted to the case where the multiplier < 1 (and non-negative).
+void QuantizeMultiplierSmallerThanOne(double double_multiplier,
+ int32_t* quantized_multiplier,
+ int* right_shift);
+
+// Decompose a double multiplier into a Q0.31 int32 representation of its
+// significand, and shift representation of its exponent.
+//
+// Restricted to the case where the multiplier > 1.
+void QuantizeMultiplierGreaterThanOne(double double_multiplier,
+ int32_t* quantized_multiplier,
+ int* left_shift);
+
+// This first creates a multiplier in a double equivalent of
+// Q(input_integer_bits).(31-input_integer_bits) representation, with extra
+// precision in the double's fractional bits. It then splits the result into
+// significand and exponent.
+void PreprocessSoftmaxScaling(double beta, double input_scale,
+ int input_integer_bits,
+ int32_t* quantized_multiplier, int* left_shift);
+
+// Calculate the largest input that will result in a within-bounds intermediate
+// result within MultiplyByQuantizedMultiplierGreaterThanOne. In other words,
+// it must not overflow before we reduce the value by multiplication by the
+// input multiplier. The negative radius is used as the minimum difference
+// in Softmax.
+int CalculateInputRadius(int input_integer_bits, int input_left_shift);
+
+} // namespace tflite
+
+#endif // PHOTOS_VISION_LEARNING_TENSORFLOW_MINI_QUANTIZATION_UTIL_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
new file mode 100644
index 0000000000..d6f306e2cb
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
@@ -0,0 +1,108 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace {
+
+using ::testing::Pair;
+
+TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOne) {
+ auto quantize = [](double d) {
+ int32_t q;
+ int s;
+ QuantizeMultiplierSmallerThanOne(d, &q, &s);
+ return std::pair<int32_t, int>{q, s};
+ };
+
+ EXPECT_DEATH(quantize(-0.1), "");
+ EXPECT_THAT(quantize(0.0), Pair(0, 0));
+ EXPECT_THAT(quantize(0.25), Pair(1073741824, 1));
+
+ // Around 0.5 we can see the change in exponent and how we try hard to
+ // void hitting max int32.
+ EXPECT_THAT(quantize(0.50 - 5e-9), Pair(2147483627, 1));
+ EXPECT_THAT(quantize(0.50 - 1e-10), Pair(1073741824, 0));
+ EXPECT_THAT(quantize(0.50), Pair(1073741824, 0));
+
+ EXPECT_THAT(quantize(0.75), Pair(1610612736, 0));
+ EXPECT_THAT(quantize(1 - 1e-9), Pair(2147483646, 0));
+
+ // If we get close enough to 1.0 it crashes and dies in one of two ways:
+ // Either the shift becomes negative or we trigger the 'less-than-one' CHECK.
+ EXPECT_DEATH(quantize(1 - 1e-15), "");
+ EXPECT_DEATH(quantize(1 - 1e-17), "");
+ EXPECT_DEATH(quantize(1.0), "");
+}
+
+TEST(QuantizationUtilTest, QuantizeMultiplierGreaterThanOne) {
+ auto quantize = [](double d) {
+ int32_t q;
+ int s;
+ QuantizeMultiplierGreaterThanOne(d, &q, &s);
+ return std::pair<int32_t, int>{q, s};
+ };
+
+ // If we are close enough to 1.0 it crashes.
+ EXPECT_DEATH(quantize(1 + 1e-16), "");
+
+ EXPECT_THAT(quantize(1 + 1e-11), Pair(1073741824, 1));
+ EXPECT_THAT(quantize(1.25), Pair(1342177280, 1));
+ EXPECT_THAT(quantize(1.50), Pair(1610612736, 1));
+ EXPECT_THAT(quantize(1.75), Pair(1879048192, 1));
+
+ // Around the powers of two we see the change in exponent. Also,
+ // we try hard to avoid hitting max int32.
+ EXPECT_THAT(quantize(2 - 1e-9), Pair(2147483647, 1));
+ EXPECT_THAT(quantize(2 - 1e-11), Pair(1073741824, 2));
+ EXPECT_THAT(quantize(2), Pair(1073741824, 2));
+}
+
+TEST(QuantizationUtilTest, PreprocessSoftmaxScaling) {
+ auto quantize = [](double beta, double scale, int integer_bits) {
+ int32_t q;
+ int s;
+ PreprocessSoftmaxScaling(beta, scale, integer_bits, &q, &s);
+ return std::pair<int32_t, int>{q, s};
+ };
+
+ // If beta * scale is greater than fits in the number of integer bits, the
+ // result is move near the maximum. Otherwise they quantize as expected.
+ // With 4 integer bits we can represent up to 16.0.
+ EXPECT_THAT(quantize(1.0, 16.0, 4), Pair(2147483647, 31));
+ EXPECT_THAT(quantize(1.0, 8.0, 4), Pair(1073741824, 31));
+ // But with 5 bits we can go further.
+ EXPECT_THAT(quantize(2.0, 16.0, 5), Pair(2147483647, 31));
+ EXPECT_THAT(quantize(2.0, 8.0, 5), Pair(1073741824, 31));
+}
+
+TEST(QuantizationUtilTest, CalculateInputRadius) {
+ EXPECT_EQ(CalculateInputRadius(4, 27), 15);
+ EXPECT_EQ(CalculateInputRadius(3, 27), 14);
+ EXPECT_EQ(CalculateInputRadius(3, 28), 7);
+ EXPECT_EQ(CalculateInputRadius(4, 2), 503316480);
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
new file mode 100644
index 0000000000..8e0f234545
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
@@ -0,0 +1,115 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_
+
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace reference_ops {
+
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int input_depth = ArraySize(input_dims, 0);
+ const int filter_height = ArraySize(filter_dims, 2);
+ const int filter_width = ArraySize(filter_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+
+ for (int b = 0; b < batches; ++b) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ for (int ic = 0; ic < input_depth; ++ic) {
+ for (int m = 0; m < depth_multiplier; m++) {
+ const int oc = m + ic * depth_multiplier;
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ float total = 0.f;
+ for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+ for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+ const int in_x = in_x_origin + filter_x;
+ const int in_y = in_y_origin + filter_y;
+ // If the location is outside the bounds of the input image,
+ // use zero as a default value.
+ if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
+ (in_y < input_height)) {
+ float input_value =
+ input_data[Offset(input_dims, ic, in_x, in_y, b)];
+ float filter_value = filter_data[Offset(
+ filter_dims, oc, filter_x, filter_y, 0)];
+ total += (input_value * filter_value);
+ }
+ }
+ }
+ float bias_value = 0.0f;
+ if (bias_data) {
+ bias_value = bias_data[Offset(bias_dims, oc, 0, 0, 0)];
+ }
+ output_data[Offset(output_dims, oc, out_x, out_y, b)] =
+ ActivationFunctionWithMinMax(total + bias_value,
+ output_activation_min,
+ output_activation_max);
+ }
+ }
+ }
+ }
+ }
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride_width, stride_height, pad_width, pad_height,
+ depth_multiplier, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ float* output_data, const Dims<4>& output_dims) {
+ DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride, stride, pad_width, pad_height,
+ depth_multiplier, output_data, output_dims);
+}
+
+} // end namespace reference_ops
+} // end namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
new file mode 100644
index 0000000000..8a80558b32
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
@@ -0,0 +1,138 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_
+
+#include <algorithm>
+
+#include "fixedpoint/fixedpoint.h"
+#include "public/gemmlowp.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace reference_ops {
+
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int input_depth = ArraySize(input_dims, 0);
+ const int filter_height = ArraySize(filter_dims, 2);
+ const int filter_width = ArraySize(filter_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+
+ for (int b = 0; b < batches; ++b) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ for (int ic = 0; ic < input_depth; ++ic) {
+ for (int m = 0; m < depth_multiplier; m++) {
+ const int oc = m + ic * depth_multiplier;
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ int32 acc = 0;
+ for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+ for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+ const int in_x = in_x_origin + filter_x;
+ const int in_y = in_y_origin + filter_y;
+ // If the location is outside the bounds of the input image,
+ // use zero as a default value.
+ if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
+ (in_y < input_height)) {
+ int32 input_val =
+ input_data[Offset(input_dims, ic, in_x, in_y, b)];
+ int32 filter_val = filter_data[Offset(filter_dims, oc,
+ filter_x, filter_y, 0)];
+ acc +=
+ (filter_val + filter_offset) * (input_val + input_offset);
+ }
+ }
+ }
+ if (bias_data) {
+ acc += bias_data[Offset(bias_dims, oc, 0, 0, 0)];
+ }
+ acc = MultiplyByQuantizedMultiplierSmallerThanOne(
+ acc, output_multiplier, output_shift);
+ acc += output_offset;
+ acc = std::max(acc, output_activation_min);
+ acc = std::min(acc, output_activation_max);
+ output_data[Offset(output_dims, oc, out_x, out_y, b)] =
+ static_cast<uint8>(acc);
+ }
+ }
+ }
+ }
+ }
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width,
+ stride_height, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
+ filter_dims, filter_offset, bias_data, bias_dims, stride,
+ stride, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+} // end namespace reference_ops
+} // end namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
new file mode 100644
index 0000000000..c5b0bccc9d
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -0,0 +1,165 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace tensor_utils {
+
+float PortableClip(float f, float abs_limit) {
+ float result = (abs_limit < f) ? abs_limit : f;
+ result = (-abs_limit > result) ? -abs_limit : result;
+ return result;
+}
+
+void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix,
+ int m_rows, int m_cols,
+ const float* vector,
+ int n_batch, float* result,
+ int result_stride) {
+ float* result_in_batch = result;
+ for (int b = 0; b < n_batch; b++) {
+ const float* matrix_ptr = matrix;
+ for (int r = 0; r < m_rows; r++) {
+ const float* vector_in_batch = vector + b * m_cols;
+ for (int c = 0; c < m_cols; c++) {
+ *result_in_batch += *matrix_ptr++ * *vector_in_batch++;
+ }
+ result_in_batch += result_stride;
+ }
+ }
+}
+
+void PortableVectorVectorCwiseProduct(const float* vector1,
+ const float* vector2, int v_size,
+ float* result) {
+ for (int v = 0; v < v_size; v++) {
+ *result++ = *vector1++ * *vector2++;
+ }
+}
+
+float PortableVectorVectorDotProduct(const float* vector1, const float* vector2,
+ int v_size) {
+ float result = 0.0;
+ for (int v = 0; v < v_size; v++) {
+ result += *vector1++ * *vector2++;
+ }
+ return result;
+}
+
+void PortableBatchVectorBatchVectorDotProduct(const float* vector1,
+ const float* vector2, int v_size,
+ int n_batch, float* result,
+ int result_stride) {
+ float* result_ptr = result;
+ const float* vector1_ptr = vector1;
+ const float* vector2_ptr = vector2;
+ for (int b = 0; b < n_batch; b++) {
+ *result_ptr =
+ PortableVectorVectorDotProduct(vector1_ptr, vector2_ptr, v_size);
+ vector1_ptr += v_size;
+ vector2_ptr += v_size;
+ result_ptr += result_stride;
+ }
+}
+
+void PortableVectorVectorCwiseProductAccumulate(const float* vector1,
+ const float* vector2,
+ int v_size, float* result) {
+ for (int v = 0; v < v_size; v++) {
+ *result++ += *vector1++ * *vector2++;
+ }
+}
+
+void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector,
+ int v_size,
+ const float* batch_vector,
+ int n_batch,
+ float* result) {
+ for (int b = 0; b < n_batch; b++) {
+ for (int v = 0; v < v_size; v++) {
+ *result++ += vector[v] * *batch_vector++;
+ }
+ }
+}
+
+void PortableVectorBatchVectorAssign(const float* vector, int v_size,
+ int n_batch, float* batch_vector) {
+ for (int b = 0; b < n_batch; b++) {
+ memcpy(batch_vector + b * v_size, vector, v_size * sizeof(float));
+ }
+}
+
+void PortableApplySigmoidToVector(const float* vector, int v_size,
+ float* result) {
+ auto sigmoid_func = ActivationFunctor(kTfLiteActSigmoid);
+ for (int v = 0; v < v_size; v++) {
+ *result++ = (sigmoid_func)(*vector++);
+ }
+}
+
+void PortableApplyActivationToVector(const float* vector, int v_size,
+ TfLiteFusedActivation activation,
+ float* result) {
+ auto activation_func = ActivationFunctor(activation);
+ for (int v = 0; v < v_size; v++) {
+ *result++ = (activation_func)(*vector++);
+ }
+}
+
+void PortableCopyVector(const float* vector, int v_size, float* result) {
+ memcpy(result, vector, v_size * sizeof(float));
+}
+
+void PortableSub1Vector(const float* vector, int v_size, float* result) {
+ for (int v = 0; v < v_size; v++) {
+ *result++ = 1.0f - *vector++;
+ }
+}
+
+void PortableZeroVector(float* vector, int v_size) {
+ memset(vector, 0, v_size * sizeof(float));
+}
+
+void PortableClipVector(const float* vector, int v_size, float abs_limit,
+ float* result) {
+ for (int v = 0; v < v_size; v++) {
+ *result++ = PortableClip(*vector++, abs_limit);
+ }
+}
+
+void PortableVectorShiftLeft(float* vector, int v_size, float shift_value) {
+ TF_LITE_ASSERT(v_size > 0);
+ for (int i = 0; i < v_size - 1; i++) {
+ vector[i] = vector[i + 1];
+ }
+ vector[v_size - 1] = shift_value;
+}
+
+void PortableReductionSumVector(const float* input_vector, float* output_vector,
+ int output_size, int reduction_size) {
+ const float* input_vector_ptr = input_vector;
+ for (int o = 0; o < output_size; o++) {
+ for (int r = 0; r < reduction_size; r++) {
+ output_vector[o] += *input_vector_ptr++;
+ }
+ }
+}
+
+} // namespace tensor_utils
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
new file mode 100644
index 0000000000..c2ab78000b
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
@@ -0,0 +1,189 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_
+
+// TDOD(ghodrat): Remove this header file and the dependency to internal data
+// structure.
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+
+namespace tflite {
+namespace tensor_utils {
+
+// Limit a float input f betweeen +abs_limit and -abs_limit.
+float PortableClip(float f, float abs_limit);
+
+// Multiply a matrix by a batch vector, and store results in a batch-size
+// vector.
+void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix,
+ int m_rows, int m_cols,
+ const float* vector,
+ int n_batch, float* result,
+ int result_stride);
+
+// Cwise product of two vectors.
+void PortableVectorVectorCwiseProduct(const float* vector1,
+ const float* vector2, int v_size,
+ float* result);
+
+// Cwise product and accumulate of two vectors. Since it's a MAC opertation, the
+// assumption here is that result array is initialized to valid values.
+void PortableVectorVectorCwiseProductAccumulate(const float* vector1,
+ const float* vector2,
+ int v_size, float* result);
+
+// Dot product of two vectors.
+float PortableVectorVectorDotProduct(const float* vector1, const float* vector2,
+ int v_size);
+
+// Dot product of two batch vectors.
+void PortableBatchVectorBatchVectorDotProduct(const float* vector1,
+ const float* vector2, int v_size,
+ int n_batch, float* result,
+ int result_stride);
+
+// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC
+// operation, the assumption here is that result array is initialized to valid
+// values.
+void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector,
+ int v_size,
+ const float* batch_vector,
+ int n_batch,
+ float* result);
+
+// Batch vector initialization with another vector.
+void PortableVectorBatchVectorAssign(const float* vector, int v_size,
+ int n_batch, float* batch_vector);
+
+// Apply sigmoid to elements of a vector.
+void PortableApplySigmoidToVector(const float* vector, int v_size,
+ float* result);
+
+// Apply activation function to elements of a vector.
+void PortableApplyActivationToVector(const float* vector, int v_size,
+ TfLiteFusedActivation activation,
+ float* result);
+
+// Copy vector to another vector.
+void PortableCopyVector(const float* vector, int v_size, float* result);
+
+// Compute "1.0f - elements of vector" (used in CIFG).
+void PortableSub1Vector(const float* vector, int v_size, float* result);
+
+// Fill vector with 0.f.
+void PortableZeroVector(float* vector, int v_size);
+
+// Clip elements of a vector using a abs_limit value.
+void PortableClipVector(const float* vector, int v_size, float abs_limit,
+ float* result);
+
+// Shift left a vector in place with v_size size.
+void PortableVectorShiftLeft(float* vector, int v_size, float shift_value);
+
+// Reduce-sum on a float input vector:
+// input_vector: float pointer to input vector.
+// output_vector: float pointer to vector.
+// output_size: output vector size.
+// reduction_size: number of consecutive elements from input vector which are
+// added to get one element of output.
+void PortableReductionSumVector(const float* input_vector, float* output_vector,
+ int output_size, int reduction_size);
+
+float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); }
+
+void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
+ int m_cols, const float* vector,
+ int n_batch, float* result,
+ int result_stride) {
+ PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector,
+ n_batch, result, result_stride);
+}
+
+void VectorVectorCwiseProduct(const float* vector1, const float* vector2,
+ int v_size, float* result) {
+ PortableVectorVectorCwiseProduct(vector1, vector2, v_size, result);
+}
+
+void VectorVectorCwiseProductAccumulate(const float* vector1,
+ const float* vector2, int v_size,
+ float* result) {
+ PortableVectorVectorCwiseProductAccumulate(vector1, vector2, v_size, result);
+}
+
+void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size,
+ const float* batch_vector,
+ int n_batch, float* result) {
+ PortableVectorBatchVectorCwiseProductAccumulate(vector, v_size, batch_vector,
+ n_batch, result);
+}
+
+float VectorVectorDotProduct(const float* vector1, const float* vector2,
+ int v_size) {
+ return PortableVectorVectorDotProduct(vector1, vector2, v_size);
+}
+
+void BatchVectorBatchVectorDotProduct(const float* vector1,
+ const float* vector2, int v_size,
+ int n_batch, float* result,
+ int result_stride) {
+ PortableBatchVectorBatchVectorDotProduct(vector1, vector2, v_size, n_batch,
+ result, result_stride);
+}
+
+void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch,
+ float* batch_vector) {
+ PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector);
+}
+
+void ApplySigmoidToVector(const float* vector, int v_size, float* result) {
+ PortableApplySigmoidToVector(vector, v_size, result);
+}
+
+void ApplyActivationToVector(const float* vector, int v_size,
+ TfLiteFusedActivation activation, float* result) {
+ PortableApplyActivationToVector(vector, v_size, activation, result);
+}
+
+void CopyVector(const float* vector, int v_size, float* result) {
+ PortableCopyVector(vector, v_size, result);
+}
+
+void Sub1Vector(const float* vector, int v_size, float* result) {
+ PortableSub1Vector(vector, v_size, result);
+}
+
+void ZeroVector(float* vector, int v_size) {
+ PortableZeroVector(vector, v_size);
+}
+
+void ClipVector(const float* vector, int v_size, float abs_limit,
+ float* result) {
+ PortableClipVector(vector, v_size, abs_limit, result);
+}
+
+void VectorShiftLeft(float* vector, int v_size, float shift_value) {
+ PortableVectorShiftLeft(vector, v_size, shift_value);
+}
+
+void ReductionSumVector(const float* input_vector, float* output_vector,
+ int output_size, int reduction_size) {
+ PortableReductionSumVector(input_vector, output_vector, output_size,
+ reduction_size);
+}
+
+} // namespace tensor_utils
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
new file mode 100644
index 0000000000..b9ca3d5c62
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -0,0 +1,2455 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_
+
+#include <stdint.h>
+#include <sys/types.h>
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <memory>
+#include <type_traits>
+
+#include "third_party/eigen3/Eigen/Core"
+#include "fixedpoint/fixedpoint.h"
+#include "public/gemmlowp.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/round.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace reference_ops {
+
+inline int32 MultiplyByQuantizedMultiplierSmallerThanOne(
+ int32 x, int32 quantized_multiplier, int right_shift) {
+ using gemmlowp::RoundingDivideByPOT;
+ using gemmlowp::SaturatingRoundingDoublingHighMul;
+ return RoundingDivideByPOT(
+ SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift);
+}
+
+inline int32 MultiplyByQuantizedMultiplierGreaterThanOne(
+ int32 x, int32 quantized_multiplier, int left_shift) {
+ using gemmlowp::SaturatingRoundingDoublingHighMul;
+ return SaturatingRoundingDoublingHighMul(x * (1 << left_shift),
+ quantized_multiplier);
+}
+
+template <typename T>
+int CountLeadingZeros(T integer_input) {
+ static_assert(std::is_unsigned<T>::value,
+ "Only unsigned integer types handled.");
+ const T one_in_leading_positive = static_cast<T>(1)
+ << (std::numeric_limits<T>::digits - 1);
+ int leading_zeros = 0;
+ while (integer_input < one_in_leading_positive) {
+ integer_input <<= 1;
+ ++leading_zeros;
+ }
+ return leading_zeros;
+}
+
+// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE
+// BROADCASTING.
+//
+// NdArrayDesc<N> describes the shape and memory layout of an N-dimensional
+// rectangular array of numbers.
+//
+// NdArrayDesc<N> is basically identical to Dims<N> defined in types.h.
+// However, as Dims<N> is to be deprecated, this class exists as an adaptor
+// to enable simple unoptimized implementations of element-wise broadcasting
+// operations.
+template <int N>
+struct NdArrayDesc {
+ // The "extent" of each dimension. Indices along dimension d must be in the
+ // half-open interval [0, extents[d]).
+ int extents[N];
+
+ // The number of *elements* (not bytes) between consecutive indices of each
+ // dimension.
+ int strides[N];
+};
+
+// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
+// ELEMENT-WISE BROADCASTING.
+//
+// Same as Offset(), except takes as NdArrayDesc<N> instead of Dims<N>.
+inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2,
+ int i3) {
+ TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]);
+ TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]);
+ TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]);
+ TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]);
+ return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] +
+ i3 * desc.strides[3];
+}
+
+// Given the dimensions of the operands for an element-wise binary broadcast,
+// adjusts them so that they can be directly iterated over with simple loops.
+// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
+// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr.
+//
+// This function assumes that the two input shapes are compatible up to
+// broadcasting and the shorter one has already been prepended with 1s to be the
+// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64),
+// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that
+// Dims<N> refer to shapes in reverse order. In this case, input0_dims will be
+// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1).
+//
+// When two shapes are compatible up to broadcasting, for each dimension d,
+// the input extents are either equal, or one of them is 1.
+//
+// This function performs the following for each dimension d:
+// - If the extents are equal, then do nothing since the loop that walks over
+// both of the input arrays is correct.
+// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1
+// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows
+// array0 to be referenced *at any index* in dimension d and still access the
+// same slice.
+template <int N>
+inline void NdArrayDescsForElementwiseBroadcast(const Dims<N>& input0_dims,
+ const Dims<N>& input1_dims,
+ NdArrayDesc<N>* desc0_out,
+ NdArrayDesc<N>* desc1_out) {
+ TFLITE_DCHECK(desc0_out != nullptr);
+ TFLITE_DCHECK(desc1_out != nullptr);
+
+ // Copy dims to desc.
+ for (int i = 0; i < N; ++i) {
+ desc0_out->extents[i] = input0_dims.sizes[i];
+ desc0_out->strides[i] = input0_dims.strides[i];
+ desc1_out->extents[i] = input1_dims.sizes[i];
+ desc1_out->strides[i] = input1_dims.strides[i];
+ }
+
+ // Walk over each dimension. If the extents are equal do nothing.
+ // Otherwise, set the desc with extent 1 to have extent equal to the other and
+ // stride 0.
+ for (int i = 0; i < N; ++i) {
+ const int extent0 = ArraySize(input0_dims, i);
+ const int extent1 = ArraySize(input1_dims, i);
+ if (extent0 != extent1) {
+ if (extent0 == 1) {
+ desc0_out->strides[i] = 0;
+ desc0_out->extents[i] = extent1;
+ } else {
+ TFLITE_DCHECK_EQ(extent1, 1);
+ desc1_out->strides[i] = 0;
+ desc1_out->extents[i] = extent0;
+ }
+ }
+ }
+}
+
+inline void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ (void)im2col_data; // only used in optimized code.
+ (void)im2col_dims; // only used in optimized code.
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
+ const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0);
+ if (bias_data) {
+ TFLITE_DCHECK_EQ(ArraySize(filter_dims, 3), ArraySize(bias_dims, 0));
+ }
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int filter_height = ArraySize(filter_dims, 2);
+ const int filter_width = ArraySize(filter_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ for (int batch = 0; batch < batches; ++batch) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ float total = 0.f;
+ for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+ for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+ for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
+ const int in_x = in_x_origin + filter_x;
+ const int in_y = in_y_origin + filter_y;
+ // If the location is outside the bounds of the input image,
+ // use zero as a default value.
+ if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
+ (in_y < input_height)) {
+ float input_value = input_data[Offset(input_dims, in_channel,
+ in_x, in_y, batch)];
+ float filter_value =
+ filter_data[Offset(filter_dims, in_channel, filter_x,
+ filter_y, out_channel)];
+ total += (input_value * filter_value);
+ }
+ }
+ }
+ }
+ float bias_value = 0.0f;
+ if (bias_data) {
+ bias_value = bias_data[Offset(bias_dims, out_channel, 0, 0, 0)];
+ }
+ output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] =
+ ActivationFunctionWithMinMax(total + bias_value,
+ output_activation_min,
+ output_activation_max);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
+ stride_width, stride_height, pad_width, pad_height,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride, stride, pad_width, pad_height, output_data,
+ output_dims, im2col_data, im2col_dims);
+}
+
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ (void)im2col_data; // only used in optimized code.
+ (void)im2col_dims; // only used in optimized code.
+ (void)gemm_context; // only used in optimized code.
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
+ const int output_depth =
+ MatchingArraySize(filter_dims, 3, bias_dims, 0, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int filter_height = ArraySize(filter_dims, 2);
+ const int filter_width = ArraySize(filter_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ for (int batch = 0; batch < batches; ++batch) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ int32 acc = 0;
+ for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+ for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+ for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
+ const int in_x = in_x_origin + filter_x;
+ const int in_y = in_y_origin + filter_y;
+ // If the location is outside the bounds of the input image,
+ // use zero as a default value.
+ if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
+ (in_y < input_height)) {
+ int32 input_val = input_data[Offset(input_dims, in_channel,
+ in_x, in_y, batch)];
+ int32 filter_val =
+ filter_data[Offset(filter_dims, in_channel, filter_x,
+ filter_y, out_channel)];
+ acc +=
+ (filter_val + filter_offset) * (input_val + input_offset);
+ }
+ }
+ }
+ }
+ if (bias_data) {
+ acc += bias_data[Offset(bias_dims, out_channel, 0, 0, 0)];
+ }
+ acc = MultiplyByQuantizedMultiplierSmallerThanOne(
+ acc, output_multiplier, output_shift);
+ acc += output_offset;
+ acc = std::max(acc, output_activation_min);
+ acc = std::min(acc, output_activation_max);
+ output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] =
+ static_cast<uint8>(acc);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width, stride_height,
+ pad_width, pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims, gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) {
+ Conv<Ac>(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride, stride, pad_width,
+ pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims, im2col_data, im2col_dims, gemm_context);
+}
+
+template <typename T>
+inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ const int input_depth = ArraySize(input_dims, 0);
+ const int input_width = ArraySize(input_dims, 1);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_batch = ArraySize(input_dims, 3);
+
+ const int output_depth = ArraySize(output_dims, 0);
+ const int output_width = ArraySize(output_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_batch = ArraySize(output_dims, 3);
+
+ TFLITE_DCHECK_EQ(input_width * block_size, output_width);
+ TFLITE_DCHECK_EQ(input_height * block_size, output_height);
+ TFLITE_DCHECK_EQ(input_depth, output_depth * block_size * block_size);
+ TFLITE_DCHECK_EQ(input_batch, output_batch);
+
+ for (int out_b = 0; out_b < output_batch; ++out_b) {
+ for (int out_h = 0; out_h < output_height; ++out_h) {
+ for (int out_w = 0; out_w < output_width; ++out_w) {
+ for (int out_d = 0; out_d < output_depth; ++out_d) {
+ const int in_d =
+ out_d + ((out_h % block_size) * block_size + out_w % block_size) *
+ output_depth;
+ const int in_w = out_w / block_size;
+ const int in_h = out_h / block_size;
+ const int in_b = out_b;
+
+ const int output_index =
+ Offset(output_dims, out_d, out_w, out_h, out_b);
+ const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b);
+
+ output_data[output_index] = input_data[input_index];
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ const int input_depth = ArraySize(input_dims, 0);
+ const int input_width = ArraySize(input_dims, 1);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_batch = ArraySize(input_dims, 3);
+
+ const int output_depth = ArraySize(output_dims, 0);
+ const int output_width = ArraySize(output_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_batch = ArraySize(output_dims, 3);
+
+ TFLITE_DCHECK_EQ(input_width, output_width * block_size);
+ TFLITE_DCHECK_EQ(input_height, output_height * block_size);
+ TFLITE_DCHECK_EQ(input_depth * block_size * block_size, output_depth);
+ TFLITE_DCHECK_EQ(input_batch, output_batch);
+
+ for (int in_b = 0; in_b < input_batch; ++in_b) {
+ for (int in_h = 0; in_h < input_height; ++in_h) {
+ for (int in_w = 0; in_w < input_width; ++in_w) {
+ for (int in_d = 0; in_d < input_depth; ++in_d) {
+ const int out_d =
+ in_d + ((in_h % block_size) * block_size + in_w % block_size) *
+ input_depth;
+ const int out_w = in_w / block_size;
+ const int out_h = in_h / block_size;
+ const int out_b = in_b;
+
+ const int output_index =
+ Offset(output_dims, out_d, out_w, out_h, out_b);
+ const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b);
+
+ output_data[output_index] = input_data[input_index];
+ }
+ }
+ }
+ }
+}
+
+inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ // TODO(benoitjacob): This really should be:
+ // const int batches = ArraySize(output_dims, 1);
+ // but the current --variable_batch hack consists in overwriting the 3rd
+ // dimension with the runtime batch size, as we don't keep track for each
+ // array of which dimension is the batch dimension in it.
+ const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
+ ArraySize(output_dims, 3);
+ const int output_depth = MatchingArraySize(weights_dims, 1, output_dims, 0);
+ const int accum_depth = ArraySize(weights_dims, 0);
+ TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
+ for (int b = 0; b < batches; ++b) {
+ for (int out_c = 0; out_c < output_depth; ++out_c) {
+ float total = 0.f;
+ for (int d = 0; d < accum_depth; ++d) {
+ total += input_data[b * accum_depth + d] *
+ weights_data[out_c * accum_depth + d];
+ }
+ float bias_value = 0.0f;
+ if (bias_data) {
+ bias_value = bias_data[Offset(bias_dims, out_c, 0, 0, 0)];
+ }
+ output_data[out_c + output_depth * b] = ActivationFunctionWithMinMax(
+ total + bias_value, output_activation_min, output_activation_max);
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data, const Dims<4>& weights_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
+ bias_dims, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ (void)gemm_context; // only used in optimized code.
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ // TODO(benoitjacob): This really should be:
+ // const int batches = ArraySize(output_dims, 1);
+ // but the current --variable_batch hack consists in overwriting the 3rd
+ // dimension with the runtime batch size, as we don't keep track for each
+ // array of which dimension is the batch dimension in it.
+ const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
+ ArraySize(output_dims, 3);
+ const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0);
+ const int accum_depth = ArraySize(filter_dims, 0);
+ TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
+ for (int b = 0; b < batches; ++b) {
+ for (int out_c = 0; out_c < output_depth; ++out_c) {
+ int32 acc = 0;
+ for (int d = 0; d < accum_depth; ++d) {
+ int32 input_val = input_data[b * accum_depth + d];
+ int32 filter_val = filter_data[out_c * accum_depth + d];
+ acc += (filter_val + filter_offset) * (input_val + input_offset);
+ }
+ if (bias_data) {
+ acc += bias_data[Offset(bias_dims, out_c, 0, 0, 0)];
+ }
+ acc = MultiplyByQuantizedMultiplierSmallerThanOne(acc, output_multiplier,
+ output_shift);
+ acc += output_offset;
+ acc = std::max(acc, output_activation_min);
+ acc = std::min(acc, output_activation_max);
+ output_data[out_c + output_depth * b] = static_cast<uint8>(acc);
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, output_offset,
+ output_multiplier, output_shift, output_activation_min,
+ output_activation_max, output_data, output_dims, gemm_context);
+}
+
+template <FusedActivationFunctionType Ac>
+void NonGlobalBatchNormalization(
+ const float* input_data, const Dims<4>& input_dims, const float* mean_data,
+ const Dims<4>& mean_dims, const float* multiplier_data,
+ const Dims<4>& multiplier_dims, const float* offset_data,
+ const Dims<4>& offset_dims, float* output_data,
+ const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height =
+ MatchingArraySize(input_dims, 2, mean_dims, 2, multiplier_dims, 2,
+ offset_dims, 2, output_dims, 2);
+ const int width =
+ MatchingArraySize(input_dims, 1, mean_dims, 1, multiplier_dims, 1,
+ offset_dims, 1, output_dims, 1);
+ const int depth =
+ MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0,
+ offset_dims, 0, output_dims, 0);
+
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction<Ac>(
+ (input_data[Offset(input_dims, c, x, y, b)] -
+ mean_data[Offset(mean_dims, c, x, y, 0)]) *
+ multiplier_data[Offset(multiplier_dims, c, x, y, 0)] +
+ offset_data[Offset(offset_dims, c, x, y, 0)]);
+ }
+ }
+ }
+ }
+}
+
+template <FusedActivationFunctionType Ac>
+void GlobalBatchNormalization(const float* input_data,
+ const Dims<4>& input_dims, const float* mean_data,
+ const Dims<4>& mean_dims,
+ const float* multiplier_data,
+ const Dims<4>& multiplier_dims,
+ const float* offset_data,
+ const Dims<4>& offset_dims, float* output_data,
+ const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth =
+ MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0,
+ offset_dims, 0, output_dims, 0);
+
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction<Ac>(
+ (input_data[Offset(input_dims, c, x, y, b)] -
+ mean_data[Offset(mean_dims, c, 0, 0, 0)]) *
+ multiplier_data[Offset(multiplier_dims, c, 0, 0, 0)] +
+ offset_data[Offset(offset_dims, c, 0, 0, 0)]);
+ }
+ }
+ }
+ }
+}
+
+inline void Relu(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ float val = input_data[Offset(input_dims, c, x, y, b)];
+ const float lower = 0;
+ float clamped = val < lower ? lower : val;
+ output_data[Offset(output_dims, c, x, y, b)] = clamped;
+ }
+ }
+ }
+ }
+}
+
+inline void Relu1(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ float val = input_data[Offset(input_dims, c, x, y, b)];
+ const float upper = 1;
+ const float lower = -1;
+ float clamped = val > upper ? upper : val < lower ? lower : val;
+ output_data[Offset(output_dims, c, x, y, b)] = clamped;
+ }
+ }
+ }
+ }
+}
+
+inline void Relu6(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ float val = input_data[Offset(input_dims, c, x, y, b)];
+ const float upper = 6;
+ const float lower = 0;
+ float clamped = val > upper ? upper : val < lower ? lower : val;
+ output_data[Offset(output_dims, c, x, y, b)] = clamped;
+ }
+ }
+ }
+ }
+}
+
+template <FusedActivationFunctionType Ac>
+void L2Normalization(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone, "");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ float squared_l2_norm = 0;
+ for (int c = 0; c < depth; ++c) {
+ float val = input_data[Offset(input_dims, c, x, y, b)];
+ squared_l2_norm += val * val;
+ }
+ float l2_norm = std::sqrt(squared_l2_norm);
+ for (int c = 0; c < depth; ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ input_data[Offset(input_dims, c, x, y, b)] / l2_norm;
+ }
+ }
+ }
+ }
+}
+
+inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt,
+ int* output_shift) {
+ *output_shift = 11;
+ while (input >= (1 << 29)) {
+ input /= 4;
+ ++*output_shift;
+ }
+ TFLITE_DCHECK_GT(input, 0);
+ const unsigned max_left_shift_bits = __builtin_clz(input) - 1;
+ const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
+ const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
+ *output_shift -= left_shift_bit_pairs;
+ input <<= 2 * left_shift_bit_pairs;
+ TFLITE_DCHECK_GE(input, (1 << 27));
+ TFLITE_DCHECK_LT(input, (1 << 29));
+ using gemmlowp::FixedPoint;
+ using gemmlowp::Rescale;
+ using gemmlowp::SaturatingRoundingMultiplyByPOT;
+ // Using 3 integer bits gives us enough room for the internal arithmetic in
+ // this Newton-Raphson iteration.
+ using F3 = FixedPoint<int32, 3>;
+ using F0 = FixedPoint<int32, 0>;
+ const F3 fixedpoint_input = F3::FromRaw(input >> 1);
+ const F3 fixedpoint_half_input =
+ SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
+ const F3 fixedpoint_half_three =
+ GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
+ // Newton-Raphson iteration
+ // Naive unoptimized starting guess: x = 1
+ F3 x = F3::One();
+ // Naive unoptimized number of iterations: 5
+ for (int i = 0; i < 5; i++) {
+ const F3 x3 = Rescale<3>(x * x * x);
+ x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
+ }
+ const F0 fixedpoint_half_sqrt_2 =
+ GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.);
+ x = x * fixedpoint_half_sqrt_2;
+ *output_inv_sqrt = x.raw();
+ if (*output_shift < 0) {
+ *output_inv_sqrt <<= -*output_shift;
+ *output_shift = 0;
+ }
+}
+
+inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, uint8* output_data,
+ const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ TFLITE_DCHECK_EQ(batches, 1);
+ TFLITE_DCHECK_EQ(height, 1);
+ TFLITE_DCHECK_EQ(width, 1);
+ int32 square_l2_norm = 0;
+ for (int i = 0; i < depth; i++) {
+ int32 diff = input_data[Offset(input_dims, i, 0, 0, 0)] - input_zero_point;
+ square_l2_norm += diff * diff;
+ }
+ int32 inv_l2norm_multiplier;
+ int inv_l2norm_shift;
+ GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier,
+ &inv_l2norm_shift);
+
+ for (int i = 0; i < depth; i++) {
+ int32 diff = input_data[Offset(input_dims, i, 0, 0, 0)] - input_zero_point;
+ int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne(
+ 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
+ int32 unclamped_output_val = 128 + rescaled_diff;
+ int32 output_val = std::min(255, std::max(0, unclamped_output_val));
+ output_data[Offset(output_dims, i, 0, 0, 0)] =
+ static_cast<uint8>(output_val);
+ }
+}
+
+inline void Add(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ const int batches =
+ MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3);
+ const int height =
+ MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2);
+ const int width =
+ MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1);
+ const int depth =
+ MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ ActivationFunctionWithMinMax(
+ input1_data[Offset(input1_dims, c, x, y, b)] +
+ input2_data[Offset(input2_dims, c, x, y, b)],
+ output_activation_min, output_activation_max);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Add(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void Add(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier, int input2_shift,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ const int batches =
+ MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3);
+ const int height =
+ MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2);
+ const int width =
+ MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1);
+ const int depth =
+ MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ const int32 input1_val =
+ input1_offset + input1_data[Offset(input1_dims, c, x, y, b)];
+ const int32 input2_val =
+ input2_offset + input2_data[Offset(input2_dims, c, x, y, b)];
+ const int32 shifted_input1_val = input1_val * (1 << left_shift);
+ const int32 shifted_input2_val = input2_val * (1 << left_shift);
+ const int32 scaled_input1_val =
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_input1_val, input1_multiplier, input1_shift);
+ const int32 scaled_input2_val =
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_input2_val, input2_multiplier, input2_shift);
+ const int32 raw_sum = scaled_input1_val + scaled_input2_val;
+ const int32 raw_output =
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ raw_sum, output_multiplier, output_shift) +
+ output_offset;
+ const int32 clamped_output =
+ std::min(output_activation_max,
+ std::max(output_activation_min, raw_output));
+ output_data[Offset(output_dims, c, x, y, b)] =
+ static_cast<uint8>(clamped_output);
+ }
+ }
+ }
+ }
+}
+
+// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary
+// dimensionality if the runtime code does a single loop over one dimension
+// that handles broadcasting as the base case. The code generator would then
+// generate max(D1, D2) nested for loops.
+template <FusedActivationFunctionType Ac>
+void BroadcastAdd(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastAdd");
+
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction<Ac>(
+ input1_data[SubscriptToIndex(desc1, c, x, y, b)] +
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ }
+ }
+ }
+ }
+}
+
+inline void BroadcastAdd(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastAdd/8bit");
+
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ const int32 input1_val =
+ input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
+ const int32 input2_val =
+ input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
+ const int32 shifted_input1_val = input1_val * (1 << left_shift);
+ const int32 shifted_input2_val = input2_val * (1 << left_shift);
+ const int32 scaled_input1_val =
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_input1_val, input1_multiplier, input1_shift);
+ const int32 scaled_input2_val =
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_input2_val, input2_multiplier, input2_shift);
+ const int32 raw_sum = scaled_input1_val + scaled_input2_val;
+ const int32 raw_output =
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ raw_sum, output_multiplier, output_shift) +
+ output_offset;
+ const int32 clamped_output =
+ std::min(output_activation_max,
+ std::max(output_activation_min, raw_output));
+ output_data[Offset(output_dims, c, x, y, b)] =
+ static_cast<uint8>(clamped_output);
+ }
+ }
+ }
+ }
+}
+
+template <FusedActivationFunctionType Ac>
+inline void BroadcastAdd(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ BroadcastAdd(left_shift, input1_data, input1_dims, input1_offset,
+ input1_multiplier, input1_shift, input2_data, input2_dims,
+ input2_offset, input2_multiplier, input2_shift, output_offset,
+ output_multiplier, output_shift, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ const int batches =
+ MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3);
+ const int height =
+ MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2);
+ const int width =
+ MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1);
+ const int depth =
+ MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ ActivationFunctionWithMinMax(
+ input1_data[Offset(input1_dims, c, x, y, b)] *
+ input2_data[Offset(input2_dims, c, x, y, b)],
+ output_activation_min, output_activation_max);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Mul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary
+// dimensionality if the runtime code does a single loop over one dimension
+// that handles broadcasting as the base case. The code generator would then
+// generate max(D1, D2) nested for loops.
+template <FusedActivationFunctionType Ac>
+void BroadcastMul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastMul");
+
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest
+ // stride, typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for
+ // the best cache behavior.
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction<Ac>(
+ input1_data[SubscriptToIndex(desc1, c, x, y, b)] *
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ }
+ }
+ }
+ }
+}
+
+inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
+ int32 input1_offset, const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit");
+
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest
+ // stride, typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for
+ // the best cache behavior.
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ const int32 input1_val =
+ input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
+ const int32 input2_val =
+ input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
+ const int32 unclamped_result =
+ output_offset +
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ input1_val * input2_val, output_multiplier, output_shift);
+ const int32 clamped_output =
+ std::min(output_activation_max,
+ std::max(output_activation_min, unclamped_result));
+ output_data[Offset(output_dims, c, x, y, b)] =
+ static_cast<uint8>(clamped_output);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
+ int32 input1_offset, const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ BroadcastMul(input1_data, input1_dims, input1_offset, input2_data,
+ input2_dims, input2_offset, output_offset, output_multiplier,
+ output_shift, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+template <FusedActivationFunctionType Ac, typename Scalar>
+void Concatenation(int concat_dim, const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, int inputs_count,
+ Scalar* output_data, const Dims<4>& output_dims) {
+ TFLITE_DCHECK_GT(inputs_count, 1);
+ int concat_size = 0;
+ for (int i = 0; i < inputs_count; i++) {
+ for (int j = 0; j < 4; j++) {
+ if (j != concat_dim) {
+ MatchingArraySize(*input_dims[i], j, output_dims, j);
+ }
+ }
+ concat_size += ArraySize(*input_dims[i], concat_dim);
+ }
+ TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim));
+ TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+ int outer_size = 1;
+ for (int i = concat_dim + 1; i < 4; i++) {
+ outer_size *= output_dims.sizes[i];
+ }
+ Scalar* output_ptr = output_data;
+ for (int k = 0; k < outer_size; k++) {
+ for (int i = 0; i < inputs_count; ++i) {
+ const int copy_size =
+ input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim];
+ memcpy(output_ptr, input_data[i] + k * copy_size,
+ copy_size * sizeof(Scalar));
+ output_ptr += copy_size;
+ }
+ }
+}
+
+template <FusedActivationFunctionType Ac, typename Scalar>
+void DepthConcatenation(const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, int inputs_count,
+ Scalar* output_data, const Dims<4>& output_dims) {
+ Concatenation<Ac, Scalar>(0, input_data, input_dims, inputs_count,
+ output_data, output_dims);
+}
+
+inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
+ const float* prev_activ_data,
+ const Dims<4>& prev_activ_dims, const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims, const float* prev_state_data,
+ const Dims<4>& prev_state_dims, float* output_state_data,
+ const Dims<4>& output_state_dims, float* output_activ_data,
+ const Dims<4>& output_activ_dims, float* concat_temp_data,
+ const Dims<4>& concat_temp_dims, float* activ_temp_data,
+ const Dims<4>& activ_temp_dims) {
+ const int batches =
+ MatchingArraySize(input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3,
+ output_state_dims, 3, output_activ_dims, 3);
+ const int height =
+ MatchingArraySize(input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2,
+ output_state_dims, 2, output_activ_dims, 2);
+ const int width =
+ MatchingArraySize(input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1,
+ output_state_dims, 1, output_activ_dims, 1);
+ TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
+ TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
+ const int input_depth = ArraySize(input_dims, 0);
+ const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
+ const int total_input_depth = prev_activ_depth + input_depth;
+ TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
+ TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
+ 1);
+ const int intern_activ_depth =
+ MatchingArraySize(weights_dims, 1, bias_dims, 0);
+ TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
+ const int output_depth =
+ MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
+ output_state_dims, 0, output_activ_dims, 0);
+ TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
+
+ // Concatenate prev_activ and input data together
+ std::vector<float const*> concat_input_arrays_data;
+ std::vector<Dims<4> const*> concat_input_arrays_dims;
+ concat_input_arrays_data.push_back(input_data);
+ concat_input_arrays_data.push_back(prev_activ_data);
+ concat_input_arrays_dims.push_back(&input_dims);
+ concat_input_arrays_dims.push_back(&prev_activ_dims);
+ Concatenation<FusedActivationFunctionType::kNone, float>(
+ 0, &(concat_input_arrays_data[0]), &(concat_input_arrays_dims[0]),
+ concat_input_arrays_data.size(), concat_temp_data, concat_temp_dims);
+
+ // Fully connected
+ FullyConnected<FusedActivationFunctionType::kNone>(
+ concat_temp_data, concat_temp_dims, weights_data, weights_dims, bias_data,
+ bias_dims, activ_temp_data, activ_temp_dims);
+
+ // Memory state update (the LSTM "guts")
+ for (int b = 0; b < batches; ++b) {
+ for (int w = 0; w < width; ++w) {
+ for (int h = 0; h < height; ++h) {
+ for (int c = 0; c < output_depth; ++c) {
+ const float input_gate =
+ 1.f /
+ (1.f + std::exp(-activ_temp_data[Offset(
+ activ_temp_dims, 0 * output_depth + c, w, h, b)]));
+ const float new_input = std::tanh(activ_temp_data[Offset(
+ activ_temp_dims, 1 * output_depth + c, w, h, b)]);
+ const float forget_gate =
+ 1.f /
+ (1.f + std::exp(-activ_temp_data[Offset(
+ activ_temp_dims, 2 * output_depth + c, w, h, b)]));
+ const float output_gate =
+ 1.f /
+ (1.f + std::exp(-activ_temp_data[Offset(
+ activ_temp_dims, 3 * output_depth + c, w, h, b)]));
+ const float new_state =
+ input_gate * new_input +
+ forget_gate *
+ prev_state_data[Offset(prev_state_dims, c, w, h, b)];
+ output_state_data[Offset(output_state_dims, c, w, h, b)] = new_state;
+ output_activ_data[Offset(output_activ_dims, c, w, h, b)] =
+ output_gate * std::tanh(new_state);
+ }
+ }
+ }
+ }
+}
+
+template <FusedActivationFunctionType Ac, typename Scalar>
+void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
+ int outputs_count, Scalar* const* output_data,
+ const Dims<4>* const* output_dims) {
+ TFLITE_DCHECK_GE(outputs_count, 1);
+ for (int i = 0; i < outputs_count; i++) {
+ /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3);
+ /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2);
+ /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1);
+ }
+ const int batches = MatchingArraySize(*output_dims[0], 3, input_dims, 3);
+ const int height = MatchingArraySize(*output_dims[0], 2, input_dims, 2);
+ const int width = MatchingArraySize(*output_dims[0], 1, input_dims, 1);
+ // for now we dont have a model with a TensorFlowSplit
+ // with fused activation function.
+ TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ int in_c = 0;
+ for (int i = 0; i < outputs_count; ++i) {
+ const int depth = ArraySize(*output_dims[i], 0);
+ for (int c = 0; c < depth; ++c) {
+ output_data[i][Offset(*output_dims[i], c, x, y, b)] =
+ input_data[Offset(input_dims, in_c, x, y, b)];
+ in_c++;
+ }
+ }
+ TFLITE_DCHECK(in_c == ArraySize(input_dims, 0));
+ }
+ }
+ }
+}
+
+// TODO(benoitjacob) make this a proper reference impl without Eigen!
+template <typename Scalar>
+using MatrixMap = typename std::conditional<
+ std::is_const<Scalar>::value,
+ Eigen::Map<const Eigen::Matrix<typename std::remove_const<Scalar>::type,
+ Eigen::Dynamic, Eigen::Dynamic>>,
+ Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
+
+template <typename Scalar, int N>
+MatrixMap<Scalar> MapAsMatrixWithFirstDimAsRows(Scalar* data,
+ const Dims<N>& dims) {
+ const int rows = dims.sizes[0];
+ int cols = 1;
+ for (int d = 1; d < N; d++) {
+ cols *= dims.sizes[d];
+ }
+ return MatrixMap<Scalar>(data, rows, cols);
+}
+
+template <typename Scalar, int N>
+MatrixMap<Scalar> MapAsMatrixWithLastDimAsCols(Scalar* data,
+ const Dims<N>& dims) {
+ const int cols = dims.sizes[N - 1];
+ int rows = 1;
+ for (int d = 0; d < N - 1; d++) {
+ rows *= dims.sizes[d];
+ }
+ return MatrixMap<Scalar>(data, rows, cols);
+}
+
+inline int NodeOffset(int b, int h, int w, int height, int width) {
+ return (b * height + h) * width + w;
+}
+
+inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ for (int batch = 0; batch < batches; ++batch) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ for (int channel = 0; channel < depth; ++channel) {
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ // Compute the boundaries of the filter region clamped so as to
+ // ensure that the filter window fits in the input array.
+ const int filter_x_start = std::max(0, -in_x_origin);
+ const int filter_x_end =
+ std::min(filter_width, input_width - in_x_origin);
+ const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_end =
+ std::min(filter_height, input_height - in_y_origin);
+ float total = 0.f;
+ float filter_count = 0;
+ for (int filter_y = filter_y_start; filter_y < filter_y_end;
+ ++filter_y) {
+ for (int filter_x = filter_x_start; filter_x < filter_x_end;
+ ++filter_x) {
+ const int in_x = in_x_origin + filter_x;
+ const int in_y = in_y_origin + filter_y;
+ total +=
+ input_data[Offset(input_dims, channel, in_x, in_y, batch)];
+ filter_count++;
+ }
+ }
+ const float average = total / filter_count;
+ output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
+ ActivationFunctionWithMinMax(average, output_activation_min,
+ output_activation_max);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width,
+ int filter_height, float* output_data,
+ const Dims<4>& output_dims) {
+ AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ for (int batch = 0; batch < batches; ++batch) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ for (int channel = 0; channel < depth; ++channel) {
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ // Compute the boundaries of the filter region clamped so as to
+ // ensure that the filter window fits in the input array.
+ const int filter_x_start = std::max(0, -in_x_origin);
+ const int filter_x_end =
+ std::min(filter_width, input_width - in_x_origin);
+ const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_end =
+ std::min(filter_height, input_height - in_y_origin);
+ int32 acc = 0;
+ int filter_count = 0;
+ for (int filter_y = filter_y_start; filter_y < filter_y_end;
+ ++filter_y) {
+ for (int filter_x = filter_x_start; filter_x < filter_x_end;
+ ++filter_x) {
+ const int in_x = in_x_origin + filter_x;
+ const int in_y = in_y_origin + filter_y;
+ acc += input_data[Offset(input_dims, channel, in_x, in_y, batch)];
+ filter_count++;
+ }
+ }
+ acc = (acc + filter_count / 2) / filter_count;
+ acc = std::max(acc, output_activation_min);
+ acc = std::min(acc, output_activation_max);
+ output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
+ static_cast<uint8>(acc);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width,
+ int filter_height, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ for (int batch = 0; batch < batches; ++batch) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ for (int channel = 0; channel < depth; ++channel) {
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ // Compute the boundaries of the filter region clamped so as to
+ // ensure that the filter window fits in the input array.
+ const int filter_x_start = std::max(0, -in_x_origin);
+ const int filter_x_end =
+ std::min(filter_width, input_width - in_x_origin);
+ const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_end =
+ std::min(filter_height, input_height - in_y_origin);
+ float sum_squares = 0.f;
+ int filter_count = 0;
+ for (int filter_y = filter_y_start; filter_y < filter_y_end;
+ ++filter_y) {
+ for (int filter_x = filter_x_start; filter_x < filter_x_end;
+ ++filter_x) {
+ const int in_x = in_x_origin + filter_x;
+ const int in_y = in_y_origin + filter_y;
+ const float val =
+ input_data[Offset(input_dims, channel, in_x, in_y, batch)];
+ sum_squares += val * val;
+ filter_count++;
+ }
+ }
+ const float l2pool_result = std::sqrt(sum_squares / filter_count);
+ output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
+ ActivationFunctionWithMinMax(l2pool_result, output_activation_min,
+ output_activation_max);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void L2Pool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ float* output_data, const Dims<4>& output_dims) {
+ L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ for (int batch = 0; batch < batches; ++batch) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ for (int channel = 0; channel < depth; ++channel) {
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ // Compute the boundaries of the filter region clamped so as to
+ // ensure that the filter window fits in the input array.
+ const int filter_x_start = std::max(0, -in_x_origin);
+ const int filter_x_end =
+ std::min(filter_width, input_width - in_x_origin);
+ const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_end =
+ std::min(filter_height, input_height - in_y_origin);
+ float max = std::numeric_limits<float>::lowest();
+ for (int filter_y = filter_y_start; filter_y < filter_y_end;
+ ++filter_y) {
+ for (int filter_x = filter_x_start; filter_x < filter_x_end;
+ ++filter_x) {
+ const int in_x = in_x_origin + filter_x;
+ const int in_y = in_y_origin + filter_y;
+ max = std::max(
+ max,
+ input_data[Offset(input_dims, channel, in_x, in_y, batch)]);
+ }
+ }
+ output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
+ ActivationFunctionWithMinMax(max, output_activation_min,
+ output_activation_max);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ float* output_data, const Dims<4>& output_dims) {
+ MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ TFLITE_DCHECK_GE(output_activation_min, 0);
+ TFLITE_DCHECK_LE(output_activation_max, 255);
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ for (int batch = 0; batch < batches; ++batch) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ for (int channel = 0; channel < depth; ++channel) {
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ // Compute the boundaries of the filter region clamped so as to
+ // ensure that the filter window fits in the input array.
+ const int filter_x_start = std::max(0, -in_x_origin);
+ const int filter_x_end =
+ std::min(filter_width, input_width - in_x_origin);
+ const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_end =
+ std::min(filter_height, input_height - in_y_origin);
+ uint8 max = 0;
+ for (int filter_y = filter_y_start; filter_y < filter_y_end;
+ ++filter_y) {
+ for (int filter_x = filter_x_start; filter_x < filter_x_end;
+ ++filter_x) {
+ const int in_x = in_x_origin + filter_x;
+ const int in_y = in_y_origin + filter_y;
+ max = std::max(
+ max,
+ input_data[Offset(input_dims, channel, in_x, in_y, batch)]);
+ }
+ }
+ max = std::max<uint8>(max, output_activation_min);
+ max = std::min<uint8>(max, output_activation_max);
+ output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
+ static_cast<uint8>(max);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void LocalResponseNormalization(const float* input_data,
+ const Dims<4>& input_dims, int range,
+ float bias, float alpha, float beta,
+ float* output_data,
+ const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ const int begin_input_c = std::max(0, c - range);
+ const int end_input_c = std::min(depth, c + range);
+ float accum = 0.f;
+ for (int input_c = begin_input_c; input_c < end_input_c; ++input_c) {
+ const float input_val =
+ input_data[Offset(input_dims, input_c, x, y, b)];
+ accum += input_val * input_val;
+ }
+ const float multiplier = std::pow(bias + alpha * accum, -beta);
+ output_data[Offset(output_dims, c, x, y, b)] =
+ input_data[Offset(input_dims, c, x, y, b)] * multiplier;
+ }
+ }
+ }
+ }
+}
+
+inline void Softmax(const float* input_data, const Dims<4>& input_dims,
+ float beta, float* output_data,
+ const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ // Find max element value which we'll use to ensure numerical stability
+ // taking advantage of the following equality:
+ // exp(x[i])/sum(exp(x[i])) == exp(x[i]+C)/sum(exp(x[i]+C))
+ float max = std::numeric_limits<float>::lowest();
+ for (int c = 0; c < depth; ++c) {
+ max = std::max(max, input_data[Offset(input_dims, c, x, y, b)]);
+ }
+
+ // Compute sum.
+ float sum = 0.f;
+ for (int c = 0; c < depth; ++c) {
+ sum += std::exp((input_data[Offset(input_dims, c, x, y, b)] - max) *
+ beta);
+ }
+
+ // Compute result.
+ for (int c = 0; c < depth; ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ std::exp((input_data[Offset(input_dims, c, x, y, b)] - max) *
+ beta) /
+ sum;
+ }
+ }
+ }
+ }
+}
+
+inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_beta_multiplier, int32 input_beta_left_shift,
+ int diff_min, uint8* output_data,
+ const Dims<4>& output_dims) {
+ // The representation chosen for the input to the exp() function is Q5.26.
+ // We need to leave extra space since values that we skip might be as large as
+ // -32 before multiplying by input_beta_multiplier, and therefore as large as
+ // -16 afterwards. Note that exp(-8) is definitely not insignificant to
+ // accumulation, but exp(-16) definitely is.
+ static const int kScaledDiffIntegerBits = 5;
+ static const int kAccumulationIntegerBits = 12;
+ using FixedPointScaledDiff =
+ gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
+ using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
+ using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
+
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+
+ for (int b = 0; b < batches; ++b) {
+ for (int x = 0; x < width; ++x) {
+ for (int y = 0; y < height; ++y) {
+ uint8 max_in_row = 0;
+ for (int c = 0; c < depth; ++c) {
+ max_in_row =
+ std::max(max_in_row, input_data[Offset(input_dims, c, x, y, b)]);
+ }
+
+ FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
+ for (int c = 0; c < depth; ++c) {
+ int32 input_diff =
+ static_cast<int32>(input_data[Offset(input_dims, c, x, y, b)]) -
+ max_in_row;
+ if (input_diff >= diff_min) {
+ const int32 input_diff_rescaled =
+ MultiplyByQuantizedMultiplierGreaterThanOne(
+ input_diff, input_beta_multiplier, input_beta_left_shift);
+ const FixedPointScaledDiff scaled_diff_f8 =
+ FixedPointScaledDiff::FromRaw(input_diff_rescaled);
+ sum_of_exps =
+ sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
+ exp_on_negative_values(scaled_diff_f8));
+ }
+ }
+
+ int32 fixed_sum_of_exps = sum_of_exps.raw();
+ int headroom_plus_one =
+ CountLeadingZeros(static_cast<uint32>(fixed_sum_of_exps));
+ // This is the number of bits to the left of the binary point above 1.0.
+ // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and
+ // no later adjustment will be needed.
+ int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one;
+ int32 shifted_sum_minus_one = static_cast<int32>(
+ (static_cast<uint32>(fixed_sum_of_exps) << headroom_plus_one) -
+ (static_cast<uint32>(1) << 31));
+
+ FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1(
+ FixedPoint0::FromRaw(shifted_sum_minus_one));
+
+ for (int c = 0; c < depth; ++c) {
+ int32 input_diff =
+ static_cast<int32>(input_data[Offset(input_dims, c, x, y, b)]) -
+ max_in_row;
+ if (input_diff >= diff_min) {
+ const int32 input_diff_rescaled =
+ MultiplyByQuantizedMultiplierGreaterThanOne(
+ input_diff, input_beta_multiplier, input_beta_left_shift);
+ const FixedPointScaledDiff scaled_diff_f8 =
+ FixedPointScaledDiff::FromRaw(input_diff_rescaled);
+
+ FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
+ int32 unsat_output = gemmlowp::RoundingDivideByPOT(
+ (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
+
+ output_data[Offset(output_dims, c, x, y, b)] = static_cast<uint8>(
+ std::max(std::min(unsat_output, static_cast<int32>(255)), 0));
+
+ } else {
+ output_data[Offset(output_dims, c, x, y, b)] = 0;
+ }
+ }
+ }
+ }
+ }
+}
+
+inline void Logistic(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ float val = input_data[Offset(input_dims, c, x, y, b)];
+ float result = 1.f / (1.f + std::exp(-val));
+ output_data[Offset(output_dims, c, x, y, b)] = result;
+ }
+ }
+ }
+ }
+}
+
+inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ const uint8 input_val_u8 = input_data[Offset(input_dims, c, x, y, b)];
+ const int32 input_val_centered =
+ static_cast<int32>(input_val_u8) - input_zero_point;
+ uint8 output_val;
+ if (input_val_centered <= -input_range_radius) {
+ output_val = 0;
+ } else if (input_val_centered >= input_range_radius) {
+ output_val = 255;
+ } else {
+ const int32 input_val_rescaled =
+ MultiplyByQuantizedMultiplierGreaterThanOne(
+ input_val_centered, input_multiplier, input_left_shift);
+ using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
+ using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
+ const FixedPoint4 input_val_f4 =
+ FixedPoint4::FromRaw(input_val_rescaled);
+ const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
+ using gemmlowp::RoundingDivideByPOT;
+ int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23);
+ if (output_val_s32 == 256) {
+ output_val_s32 = 255;
+ }
+ TFLITE_DCHECK_GE(output_val_s32, 0);
+ TFLITE_DCHECK_LE(output_val_s32, 255);
+ output_val = static_cast<uint8>(output_val_s32);
+ }
+ output_data[Offset(output_dims, c, x, y, b)] = output_val;
+ }
+ }
+ }
+ }
+}
+
+inline void Tanh(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ float val = input_data[Offset(input_dims, c, x, y, b)];
+ float result = std::tanh(val);
+ output_data[Offset(output_dims, c, x, y, b)] = result;
+ }
+ }
+ }
+ }
+}
+
+inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
+ int32 zero_point, double scale, float* output_data,
+ const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ int32 val = input_data[Offset(input_dims, c, x, y, b)];
+ float result = static_cast<float>(scale * (val - zero_point));
+ output_data[Offset(output_dims, c, x, y, b)] = result;
+ }
+ }
+ }
+ }
+}
+
+inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
+ float rmin, float rmax, float* output_data,
+ const Dims<4>& output_dims) {
+ // 0 should always be a representable value. Let's assume that the initial
+ // min,max range contains 0.
+ TFLITE_DCHECK_LE(rmin, 0.);
+ TFLITE_DCHECK_GE(rmax, 0.);
+
+ // Determine quantization parameters: zero_point, scale.
+ using Integer = uint8;
+ const Integer qmin = std::numeric_limits<Integer>::min();
+ const Integer qmax = std::numeric_limits<Integer>::max();
+ const float qmin_float = qmin;
+ const float qmax_float = qmax;
+ int32 zero_point = 0;
+ float scale = 0.f;
+ // If rmin==rmax, both must be zero per the above assertion,
+ // so we are done.
+ if (rmin != rmax) {
+ // First determine the scale.
+ scale = (rmax - rmin) / (qmax_float - qmin_float);
+
+ // Zero-point computation.
+ // First the initial floating-point computation. The zero-point can be
+ // determined from solving an affine equation for any known pair
+ // (real value, corresponding quantized value).
+ // We know two such pairs: (rmin, qmin) and (rmax, qmax).
+ // The arithmetic error on the zero point computed from either pair
+ // will be roughly machine_epsilon * (sum of absolute values of terms)
+ // so we want to use the variant that adds the smaller terms.
+ const float zero_point_from_min = qmin_float - rmin / scale;
+ const float zero_point_from_max = qmax_float - rmax / scale;
+ const float zero_point_from_min_error =
+ std::abs(qmin_float) + std::abs(rmin / scale);
+ const float zero_point_from_max_error =
+ std::abs(qmax_float) + std::abs(rmax / scale);
+
+ const float zero_point_float =
+ zero_point_from_min_error < zero_point_from_max_error
+ ? zero_point_from_min
+ : zero_point_from_max;
+
+ // Now we need to nudge the zero point to be an integer
+ // (our zero points are integer, and this is motivated by the requirement
+ // to be able to represent the real value "0" exactly as a quantized value,
+ // which is required in multiple places, for example in Im2col with SAME
+ // padding).
+ if (zero_point_float < qmin_float) {
+ zero_point = qmin;
+ } else if (zero_point_float > qmax_float) {
+ zero_point = qmax;
+ } else {
+ zero_point = static_cast<int32>(TfLiteRound(zero_point_float));
+ }
+ // The zero point should always be in the range of quantized value,
+ // [qmin, qmax].
+ TFLITE_DCHECK_GE(zero_point, qmin);
+ TFLITE_DCHECK_LE(zero_point, qmax);
+ }
+
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ const float src_val = input_data[Offset(input_dims, c, x, y, b)];
+ const float unclamped_quantized_val =
+ TfLiteRound(zero_point + src_val / scale);
+ const float quantized_val = std::min(
+ qmax_float, std::max(qmin_float, unclamped_quantized_val));
+ const float dst_val = scale * (quantized_val - zero_point);
+ output_data[Offset(output_dims, c, x, y, b)] = dst_val;
+ }
+ }
+ }
+ }
+}
+
+template <typename SrcT, typename DstT>
+inline void Cast(const SrcT* input_data, const Dims<4>& input_dims,
+ DstT* output_data, const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ int offset = Offset(input_dims, c, x, y, b);
+ output_data[offset] = static_cast<DstT>(input_data[offset]);
+ }
+ }
+ }
+ }
+}
+
+inline void Floor(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ int offset = Offset(input_dims, c, x, y, b);
+ output_data[offset] = std::floor(input_data[offset]);
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+inline void Gather(const T* input_data, const Dims<4>& input_dims,
+ int input_rank, const int32* coords_data,
+ const Dims<4>& coords_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]);
+ int stride = input_dims.strides[input_rank - 1];
+ T* out = output_data;
+
+ for (int i = 0; i < coords_dims.sizes[0]; i++) {
+ TFLITE_DCHECK_GE(coords_data[i], 0);
+ TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]);
+ const T* in = input_data + coords_data[i] * stride;
+ memcpy(out, in, sizeof(T) * stride);
+ out += stride;
+ }
+}
+
+inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, float* output_data,
+ const Dims<4>& output_dims) {
+ int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ int32 input_height = ArraySize(input_dims, 2);
+ int32 input_width = ArraySize(input_dims, 1);
+ int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+
+ TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1);
+ TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1);
+ TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1);
+ TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2);
+ int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)];
+ int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)];
+ float height_scale = static_cast<float>(input_height) / output_height;
+ float width_scale = static_cast<float>(input_width) / output_width;
+
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < output_height; ++y) {
+ float input_y = y * height_scale;
+ int32 y0 = static_cast<int32>(std::floor(input_y));
+ int32 y1 = std::min(y0 + 1, input_height - 1);
+ for (int x = 0; x < output_width; ++x) {
+ float input_x = x * width_scale;
+ int32 x0 = static_cast<int32>(std::floor(input_x));
+ int32 x1 = std::min(x0 + 1, input_width - 1);
+ for (int c = 0; c < depth; ++c) {
+ float interpolation = input_data[Offset(input_dims, c, x0, y0, b)] *
+ (1 - (input_y - y0)) *
+ (1 - (input_x - x0)) +
+ input_data[Offset(input_dims, c, x0, y1, b)] *
+ (input_y - y0) * (1 - (input_x - x0)) +
+ input_data[Offset(input_dims, c, x1, y0, b)] *
+ (1 - (input_y - y0)) * (input_x - x0) +
+ input_data[Offset(input_dims, c, x1, y1, b)] *
+ (input_y - y0) * (input_x - x0);
+ output_data[Offset(output_dims, c, x, y, b)] = interpolation;
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims,
+ const int32* paddings_data,
+ const Dims<4>& paddings_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ const int output_batch_size = ArraySize(output_dims, 3);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ const int input_batch_size = ArraySize(input_dims, 3);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int depth = ArraySize(input_dims, 0);
+ const int block_shape_height = block_shape_data[0];
+ const int block_shape_width = block_shape_data[1];
+ const int padding_top = paddings_data[0];
+ const int padding_left = paddings_data[2];
+
+ for (int out_b = 0; out_b < output_batch_size; ++out_b) {
+ int input_batch = out_b % input_batch_size;
+ int shift_w = (out_b / input_batch_size) % block_shape_width;
+ int shift_h = (out_b / input_batch_size) / block_shape_width;
+ for (int out_h = 0; out_h < output_height; ++out_h) {
+ for (int out_w = 0; out_w < output_width; ++out_w) {
+ T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b);
+ if (out_h * block_shape_height < padding_top ||
+ out_h * block_shape_height >= padding_top + input_height ||
+ out_w * block_shape_width < padding_left ||
+ out_w * block_shape_width >= padding_left + input_width) {
+ memset(out, 0, depth * sizeof(T));
+ } else {
+ const T* in =
+ input_data +
+ Offset(input_dims, 0,
+ (out_w * block_shape_width + shift_w) - padding_left,
+ (out_h * block_shape_height + shift_h) - padding_top,
+ input_batch);
+ memcpy(out, in, depth * sizeof(T));
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ const int output_batch_size = ArraySize(output_dims, 3);
+ const int input_batch_size = ArraySize(input_dims, 3);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int depth = ArraySize(input_dims, 0);
+ const int block_shape_width = block_shape_data[1];
+ const int block_shape_height = block_shape_data[0];
+
+ for (int in_batch = 0; in_batch < input_batch_size; ++in_batch) {
+ for (int in_h = 0; in_h < input_height; ++in_h) {
+ for (int in_w = 0; in_w < input_width; ++in_w) {
+ int out_batch = in_batch % output_batch_size;
+ int out_w = in_w * block_shape_width +
+ (in_batch / output_batch_size) % block_shape_width;
+ int out_h = in_h * block_shape_height +
+ (in_batch / output_batch_size) / block_shape_width;
+ T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch);
+ const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch);
+ memcpy(out, in, depth * sizeof(T));
+ }
+ }
+ }
+}
+
+template <typename T>
+inline void Pad(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims) {
+ const int output_batch = ArraySize(output_dims, 3);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ const int output_depth = ArraySize(output_dims, 0);
+
+ const int left_b_padding = left_paddings[3];
+ const int left_h_padding = left_paddings[2];
+ const int left_w_padding = left_paddings[1];
+ const int left_d_padding = left_paddings[0];
+
+ const int right_b_padding = right_paddings[3];
+ const int right_h_padding = right_paddings[2];
+ const int right_w_padding = right_paddings[1];
+ const int right_d_padding = right_paddings[0];
+
+ const T* in_ptr = input_data;
+ T* out_ptr = output_data;
+ for (int out_b = 0; out_b < output_batch; ++out_b) {
+ for (int out_h = 0; out_h < output_height; ++out_h) {
+ for (int out_w = 0; out_w < output_width; ++out_w) {
+ for (int out_d = 0; out_d < output_depth; ++out_d) {
+ if (out_b < left_b_padding ||
+ out_b >= output_batch - right_b_padding ||
+ out_h < left_h_padding ||
+ out_h >= output_height - right_h_padding ||
+ out_w < left_w_padding ||
+ out_w >= output_width - right_w_padding ||
+ out_d < left_d_padding ||
+ out_d >= output_depth - right_d_padding) {
+ *out_ptr++ = 0;
+ } else {
+ *out_ptr++ = *in_ptr++;
+ }
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
+ int begin_mask, int end_mask,
+ const std::vector<int>& starts,
+ const std::vector<int>& stops,
+ const std::vector<int>& strides, T* output_data,
+ const Dims<4>& output_dims) {
+ const int start_b = (begin_mask & 8) ? 0 : starts[3];
+ const int stop_b = (end_mask & 8) ? input_dims.sizes[3] : stops[3];
+ const int start_h = (begin_mask & 4) ? 0 : starts[2];
+ const int stop_h = (end_mask & 4) ? input_dims.sizes[2] : stops[2];
+ const int start_w = (begin_mask & 2) ? 0 : starts[1];
+ const int stop_w = (end_mask & 2) ? input_dims.sizes[1] : stops[1];
+ const int start_d = (begin_mask & 1) ? 0 : starts[0];
+ const int stop_d = (end_mask & 1) ? input_dims.sizes[0] : stops[0];
+
+ T* out_ptr = output_data;
+ for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) {
+ for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) {
+ for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) {
+ for (int in_d = start_d; in_d < stop_d; in_d += strides[0]) {
+ *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+inline void Slice(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& begin, const std::vector<int>& size,
+ T* output_data, const Dims<4>& output_dims) {
+ // TODO(dkalenichenko): This op only supports 4D tensors.
+ TFLITE_DCHECK_EQ(begin.size(), 4);
+ TFLITE_DCHECK_EQ(size.size(), 4);
+ const int start_b = begin[3];
+ const int stop_b =
+ size[3] == -1 ? input_dims.sizes[3] - start_b : start_b + size[3];
+ const int start_h = begin[2];
+ const int stop_h =
+ size[2] == -1 ? input_dims.sizes[2] - start_b : start_b + size[2];
+ const int start_w = begin[1];
+ const int stop_w =
+ size[1] == -1 ? input_dims.sizes[1] - start_b : start_b + size[1];
+ const int start_d = begin[0];
+ const int stop_d =
+ size[0] == -1 ? input_dims.sizes[0] - start_d : start_d + size[0];
+
+ T* out_ptr = output_data;
+ for (int in_b = start_b; in_b < stop_b; ++in_b) {
+ for (int in_h = start_h; in_h < stop_h; ++in_h) {
+ for (int in_w = start_w; in_w < stop_w; ++in_w) {
+ for (int in_d = start_d; in_d < stop_d; ++in_d) {
+ *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+inline void Mean(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& reduction_indices, T* output_data,
+ const Dims<4>& output_dims) {
+ const int output_batch = ArraySize(output_dims, 3);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ const int output_depth = ArraySize(output_dims, 0);
+
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+
+ // The current implementation only supports simultaneous reduction over
+ // width and height.
+ TFLITE_DCHECK_EQ(reduction_indices.size(), 2);
+ TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) ||
+ (reduction_indices[0] == 2 && reduction_indices[1] == 1));
+ TFLITE_DCHECK_EQ(output_height, 1);
+ TFLITE_DCHECK_EQ(output_width, 1);
+
+ for (int out_b = 0; out_b < output_batch; ++out_b) {
+ for (int out_d = 0; out_d < output_depth; ++out_d) {
+ float value = 0;
+ for (int in_h = 0; in_h < input_height; ++in_h) {
+ for (int in_w = 0; in_w < input_width; ++in_w) {
+ value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)];
+ }
+ }
+ output_data[Offset(output_dims, out_d, 0, 0, out_b)] =
+ value / (input_width * input_height);
+ }
+ }
+}
+
+template <typename T>
+void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
+ const Dims<4>& input2_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ input1_data[SubscriptToIndex(desc1, c, x, y, b)] -
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)];
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ int batches = MatchingArraySize(input1_dims, 3, output_dims, 3);
+ int input_height = MatchingArraySize(input1_dims, 2, output_dims, 2);
+ int input_width = MatchingArraySize(input1_dims, 1, output_dims, 1);
+ int depth = MatchingArraySize(input1_dims, 0, output_dims, 0);
+
+ auto min_value = input2_data[0];
+
+ for (int b = 0; b < batches; b++) {
+ for (int y = 0; y < input_height; y++) {
+ for (int x = 0; x < input_width; x++) {
+ for (int c = 0; c < depth; c++) {
+ int offset = Offset(input1_dims, c, x, y, b);
+ output_data[offset] =
+ input1_data[offset] > min_value ? min_value : input1_data[offset];
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ int batches = MatchingArraySize(input1_dims, 3, output_dims, 3);
+ int input_height = MatchingArraySize(input1_dims, 2, output_dims, 2);
+ int input_width = MatchingArraySize(input1_dims, 1, output_dims, 1);
+ int depth = MatchingArraySize(input1_dims, 0, output_dims, 0);
+
+ auto max_value = input2_data[0];
+
+ for (int b = 0; b < batches; b++) {
+ for (int y = 0; y < input_height; y++) {
+ for (int x = 0; x < input_width; x++) {
+ for (int c = 0; c < depth; c++) {
+ int offset = Offset(input1_dims, c, x, y, b);
+ output_data[offset] =
+ input1_data[offset] < max_value ? max_value : input1_data[offset];
+ }
+ }
+ }
+ }
+}
+
+} // namespace reference_ops
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/round.h b/tensorflow/contrib/lite/kernels/internal/round.h
new file mode 100644
index 0000000000..38525b0e20
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/round.h
@@ -0,0 +1,39 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_
+
+#include <cmath>
+
+namespace tflite {
+
+// TODO(aselle): See if we can do this only on jdk. Also mikecase, check
+// if you need this for java host build.
+#if defined(__ANDROID__) && !defined(__NDK_MAJOR__)
+template <class T>
+inline float TfLiteRound(const float x) {
+ return ::round(x);
+}
+inline double TfLiteRound(const double x) { return ::round(x); }
+#else
+template <class T>
+inline T TfLiteRound(const T x) {
+ return std::round(x);
+}
+#endif
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h
new file mode 100644
index 0000000000..ee4111e041
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/tensor.h
@@ -0,0 +1,87 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_
+
+#include <vector>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+template <typename T>
+inline T* GetTensorData(TfLiteTensor* tensor);
+
+template <>
+inline float* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.f : nullptr;
+}
+
+template <>
+inline uint8_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.uint8 : nullptr;
+}
+
+template <>
+inline int32_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i32 : nullptr;
+}
+
+template <>
+inline int64_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? reinterpret_cast<int64_t*>(tensor->data.raw)
+ : nullptr;
+}
+
+inline int RemapDim(int max_dimensions, int d) {
+ return max_dimensions - d - 1;
+}
+
+// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object
+// even if the original tensors were not 4D. We should consider rewriting them
+// to take a more generic 'shape' object.
+inline Dims<4> GetTensorDims(const int data[], const int size) {
+ Dims<4> d;
+ for (int i = 0; i < 4; ++i) {
+ int src = size - i - 1;
+ if (src >= 0) {
+ d.sizes[i] = data[src];
+ } else {
+ d.sizes[i] = 1;
+ }
+ }
+ d.strides[0] = 1;
+ for (int i = 1; i < 4; i++) {
+ d.strides[i] = d.strides[i - 1] * d.sizes[i - 1];
+ }
+ return d;
+}
+
+inline Dims<4> GetTensorDims(std::vector<int32_t> data) {
+ return GetTensorDims(data.data(), data.size());
+}
+
+inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) {
+ if (tensor == nullptr) {
+ return Dims<4>();
+ }
+
+ auto* dims = tensor->dims;
+ return GetTensorDims(dims->data, dims->size);
+}
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc
new file mode 100644
index 0000000000..bf2068d320
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc
@@ -0,0 +1,55 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+
+TEST(TensorTest, GetTensorDims4D) {
+ Dims<4> d = GetTensorDims({2, 3, 4, 5});
+ EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 2));
+ EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60));
+}
+
+TEST(TensorTest, GetTensorDims3D) {
+ Dims<4> d = GetTensorDims({3, 4, 5});
+ EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 1));
+ EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60));
+}
+
+TEST(TensorTest, GetTensorDims2D) {
+ Dims<4> d = GetTensorDims({4, 5});
+ EXPECT_THAT(d.sizes, ElementsAre(5, 4, 1, 1));
+ EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 20));
+}
+
+TEST(TensorTest, GetTensorDims1D) {
+ Dims<4> d = GetTensorDims({5});
+ EXPECT_THAT(d.sizes, ElementsAre(5, 1, 1, 1));
+ EXPECT_THAT(d.strides, ElementsAre(1, 5, 5, 5));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc
new file mode 100644
index 0000000000..904a97803a
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc
@@ -0,0 +1,27 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+
+#ifndef USE_NEON
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+#define USE_NEON
+#endif // defined(__ARM_NEON__) || defined(__ARM_NEON)
+#endif // USE_NEON
+
+#ifdef USE_NEON
+#include "tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h"
+#else
+#include "tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h"
+#endif // USE_NEON
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
new file mode 100644
index 0000000000..0e69ef5982
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
@@ -0,0 +1,116 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+
+namespace tflite {
+namespace tensor_utils {
+
+// Limit a float input f betweeen +abs_limit and -abs_limit.
+float Clip(float f, float abs_limit);
+
+// Multiply a matrix by a batch vector, and store results in a batch-size
+// vector using a stride value provided in result_stride. 'result_stride' shows
+// how the number of elements between consecutive result values. For example
+// result_stride = 1, will cause the output to look like this:
+// [O_1, 0_2, ... O_rows] in memory, but result_stride = 3, will cause it to be
+// arranged like this in memory: [O_1, x, x, 0_2, x, x, ..., O_rows]
+void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
+ int m_cols, const float* vector,
+ int n_batch, float* result,
+ int result_stride);
+
+// Cwise product of two vectors.
+void VectorVectorCwiseProduct(const float* vector1, const float* vector2,
+ int v_size, float* result);
+
+// Cwise product and accumulate of two vectors. Since it's a MAC opertation, the
+// assumption here is that result array is initialized to valid values.
+void VectorVectorCwiseProductAccumulate(const float* vector1,
+ const float* vector2, int v_size,
+ float* result);
+
+// Dot product of two vectors.
+float VectorVectorDotProduct(const float* vector1, const float* vector2,
+ int v_size);
+
+// Dot product of two batch vectors of size n_batch * v_size:
+// vector1 = [x_1_1, x_1_2, ..., x_1_vsize,
+// x_2_1, x_2_2, ..., x_2_vsize,
+// ...
+// x_nbatch_1,..., x_nbatch_vsize]
+// vector2 = [y_1_1, y_1_2, ..., y_1_vsize,
+// y_2_1, y_2_2, ..., y_2_vsize,
+// ...
+// y_nbatch_1,..., y_nbatch_vsize]
+// Then result will be a vector of n_batch size which will be saved with a
+// stride of result_stride in memory starting from 'result':
+// [x_1_1 * y_1_1 + x_1_2 * y_1_2 + ... + x_1_vsize * y_1_vsize,
+// x_2_1 * y_2_1 + x_2_2 * y_2_2 + ... + x_2_vsize * y_2_vsize,
+// ...
+// x_nbatch_1 * y_nbatch_1 + ... + x_nbatch_vsize * y_nbatch_vsize]
+void BatchVectorBatchVectorDotProduct(const float* vector1,
+ const float* vector2, int v_size,
+ int n_batch, float* result,
+ int result_stride);
+
+// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC
+// operation, the assumption here is that result array is initialized to valid
+// values.
+void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size,
+ const float* batch_vector,
+ int n_batch, float* result);
+
+// Batch vector initialization with another vector.
+void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch,
+ float* batch_vector);
+
+// Apply sigmoid to elements of a vector.
+void ApplySigmoidToVector(const float* vector, int v_size, float* result);
+
+// Apply activation function to elements of a vector.
+void ApplyActivationToVector(const float* vector, int v_size,
+ TfLiteFusedActivation activation, float* result);
+
+// Copy vector to another vector.
+void CopyVector(const float* vector, int v_size, float* result);
+
+// Compute "1.0f - elements of vector" (used in CIFG).
+void Sub1Vector(const float* vector, int v_size, float* result);
+
+// Fill vector with 0.f.
+void ZeroVector(float* vector, int v_size);
+
+// Clip elements of a vector using a abs_limit value.
+void ClipVector(const float* vector, int v_size, float abs_limit,
+ float* result);
+
+// Shift left a vector in place with v_size size.
+void VectorShiftLeft(float* vector, int v_size, float shift_value);
+
+// Reduce-sum on a float input vector:
+// input_vector: float pointer to input vector.
+// output_vector: float pointer to vector.
+// output_size: output vector size.
+// reduction_size: number of consecutive elements from input vector which are
+// added to get one element of output.
+void ReductionSumVector(const float* input_vector, float* output_vector,
+ int output_size, int reduction_size);
+} // namespace tensor_utils
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
new file mode 100644
index 0000000000..588f1a428b
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
@@ -0,0 +1,192 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include <gmock/gmock.h>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+
+namespace tflite {
+namespace tensor_utils {
+
+TEST(uKernels, ClipTest) {
+ constexpr int kVectorSize = 10;
+ constexpr float kAbsLimit = 2.0;
+ static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0,
+ -2.5, 3.0, -3.5, 4.0, -4.5};
+ std::vector<float> output(kVectorSize);
+ ClipVector(input, kVectorSize, kAbsLimit, output.data());
+ EXPECT_THAT(output,
+ ElementsAreArray(ArrayFloatNear(
+ {0.0, -0.5, 1.0, -1.5, 2.0, -2.0, 2.0, -2.0, 2.0, -2.0})));
+}
+
+TEST(uKernels, MatrixBatchVectorMultiplyAccumulateTest) {
+ constexpr int kRow = 3;
+ constexpr int kCol = 4;
+ constexpr int kBatch = 2;
+ static float matrix[kRow * kCol] = {1.0, 2.0, 3.0, 4.0, //
+ -1.0, -2.0, -3.0, -4.0, //
+ 1.0, -2.0, 3.0, -4.0};
+ static float vector[kCol * kBatch] = {1.0, -1.0, 1.0, -1.0, //
+ 2.0, -2.0, 2.0, -2.0};
+ std::vector<float> output(kRow * kBatch);
+ std::fill(output.begin(), output.end(), 3.0);
+ MatrixBatchVectorMultiplyAccumulate(matrix, kRow, kCol, vector, kBatch,
+ output.data(), /*result_stride=*/1);
+ EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear({1., 5., 13., //
+ -1., 7., 23.})));
+
+ std::vector<float> output_with_stride2(kRow * kBatch * 2);
+ std::fill(output_with_stride2.begin(), output_with_stride2.end(), 3.0);
+ MatrixBatchVectorMultiplyAccumulate(matrix, kRow, kCol, vector, kBatch,
+ output_with_stride2.data(),
+ /*result_stride=*/2);
+ EXPECT_THAT(output_with_stride2,
+ ElementsAreArray(ArrayFloatNear({1., 3., 5., 3., 13., 3., //
+ -1., 3., 7., 3., 23., 3.})));
+}
+
+TEST(uKernels, VectorVectorCwiseProductTest) {
+ constexpr int kVectorSize = 10;
+ static float input1[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0,
+ -2.5, 3.0, -3.5, 4.0, -4.5};
+ static float input2[kVectorSize] = {0.1, -0.1, 0.1, -0.1, 0.1,
+ -0.1, 0.1, -0.1, 0.1, -0.1};
+ std::vector<float> output(kVectorSize);
+ VectorVectorCwiseProduct(input1, input2, kVectorSize, output.data());
+ EXPECT_THAT(output,
+ ElementsAreArray(ArrayFloatNear(
+ {0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45})));
+}
+
+TEST(uKernels, VectorVectorCwiseProductAccumulateTest) {
+ constexpr int kVectorSize = 10;
+ static float input1[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0,
+ -2.5, 3.0, -3.5, 4.0, -4.5};
+ static float input2[kVectorSize] = {0.1, -0.1, 0.1, -0.1, 0.1,
+ -0.1, 0.1, -0.1, 0.1, -0.1};
+ std::vector<float> output(kVectorSize);
+ std::fill(output.begin(), output.end(), 1.0);
+ VectorVectorCwiseProductAccumulate(input1, input2, kVectorSize,
+ output.data());
+ EXPECT_THAT(output,
+ ElementsAreArray(ArrayFloatNear(
+ {1.0, 1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4, 1.45})));
+}
+
+TEST(uKernels, VectorBatchVectorAssignTest) {
+ constexpr int kVectorSize = 5;
+ constexpr int kBatchSize = 3;
+ static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0};
+ std::vector<float> output(kVectorSize * kBatchSize);
+ VectorBatchVectorAssign(input, kVectorSize, kBatchSize, output.data());
+ EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear(
+ {0.0, -0.5, 1.0, -1.5, 2.0, 0.0, -0.5, 1.0, -1.5, 2.0,
+ 0.0, -0.5, 1.0, -1.5, 2.0})));
+}
+
+TEST(uKernels, ApplySigmoidToVectorTest) {
+ constexpr int kVectorSize = 5;
+ static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0};
+ std::vector<float> output(kVectorSize);
+ ApplySigmoidToVector(input, kVectorSize, output.data());
+ EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear(
+ {0.5, 0.377541, 0.731059, 0.182426, 0.880797})));
+}
+
+TEST(uKernels, ApplyActivationToVectorTest) {
+ constexpr int kVectorSize = 5;
+ static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0};
+ std::vector<float> output(kVectorSize);
+ ApplyActivationToVector(input, kVectorSize, kTfLiteActRelu, output.data());
+ EXPECT_THAT(output,
+ ElementsAreArray(ArrayFloatNear({0.0, 0.0, 1.0, 0.0, 2.0})));
+
+ ApplyActivationToVector(input, kVectorSize, kTfLiteActTanh, output.data());
+ EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear(
+ {0.0, -0.462117, 0.761594, -0.905148, 0.964028})));
+}
+
+TEST(uKernels, CopyVectorTest) {
+ constexpr int kVectorSize = 5;
+ static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0};
+ std::vector<float> output(kVectorSize);
+ CopyVector(input, kVectorSize, output.data());
+ EXPECT_THAT(output,
+ ElementsAreArray(ArrayFloatNear({0.0, -0.5, 1.0, -1.5, 2.0})));
+}
+
+TEST(uKernels, Sub1VectorTest) {
+ constexpr int kVectorSize = 5;
+ static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0};
+ std::vector<float> output(kVectorSize);
+ Sub1Vector(input, kVectorSize, output.data());
+ EXPECT_THAT(output,
+ ElementsAreArray(ArrayFloatNear({1.0, 1.5, 0.0, 2.5, -1.0})));
+}
+
+TEST(uKernels, ZeroVectorTest) {
+ constexpr int kVectorSize = 5;
+ std::vector<float> output(kVectorSize);
+ ZeroVector(output.data(), kVectorSize);
+ EXPECT_THAT(output,
+ ElementsAreArray(ArrayFloatNear({0.0, 0.0, 0.0, 0.0, 0.0})));
+}
+
+TEST(uKernels, BatchVectorBatchVectorDotProductTest) {
+ constexpr int kVectorSize = 5;
+ constexpr int kBatch = 2;
+ static float input1[kVectorSize * kBatch] = {0.0, -0.5, 1.0, -1.5, 2.0,
+ -2.5, 3.0, -3.5, 4.0, -4.5};
+ static float input2[kVectorSize * kBatch] = {0.1, -0.1, 0.1, -0.1, 0.1,
+ -0.1, 0.1, -0.1, 0.1, -0.1};
+ std::vector<float> output(kBatch);
+ BatchVectorBatchVectorDotProduct(input1, input2, kVectorSize, kBatch,
+ output.data(), /*result_stride=*/1);
+ EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear({0.5, 1.75})));
+}
+
+TEST(uKernels, VectorShiftLeftTest) {
+ constexpr int kVectorSize = 5;
+ static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0};
+ std::vector<float> result(kVectorSize);
+ VectorShiftLeft(input, kVectorSize, 3.0);
+ result.assign(input, input + kVectorSize);
+ EXPECT_THAT(result,
+ ElementsAreArray(ArrayFloatNear({-0.5, 1.0, -1.5, 2.0, 3.0})));
+}
+
+TEST(uKernels, ReductionSumVectorTest) {
+ constexpr int kInputVectorSize = 10;
+ constexpr int kOutputVectorSize1 = 5;
+ constexpr int kReductionSize1 = 2;
+ static float input[kInputVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0,
+ 0.0, -0.5, 1.0, 1.0, 2.0};
+ std::vector<float> result1(kOutputVectorSize1);
+ ReductionSumVector(input, result1.data(), kOutputVectorSize1,
+ kReductionSize1);
+ EXPECT_THAT(result1,
+ ElementsAreArray(ArrayFloatNear({-0.5, -0.5, 2.0, 0.5, 3.0})));
+
+ constexpr int kOutputVectorSize2 = 2;
+ constexpr int kReductionSize2 = 5;
+ std::vector<float> result2(kOutputVectorSize2);
+ ReductionSumVector(input, result2.data(), kOutputVectorSize2,
+ kReductionSize2);
+ EXPECT_THAT(result2, ElementsAreArray(ArrayFloatNear({1.0, 3.5})));
+}
+
+} // namespace tensor_utils
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
new file mode 100644
index 0000000000..07f1cb4004
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -0,0 +1,81 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
+
+#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+
+namespace tflite {
+
+enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu };
+
+template <int N>
+struct Dims {
+ int sizes[N];
+ int strides[N];
+};
+
+inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
+ TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]);
+ TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]);
+ TFLITE_DCHECK(i2 >= 0 && i2 < dims.sizes[2]);
+ TFLITE_DCHECK(i3 >= 0 && i3 < dims.sizes[3]);
+ return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] +
+ i3 * dims.strides[3];
+}
+
+// Get array size, DCHECKing that the dim index is in range.
+template <int N>
+int ArraySize(const Dims<N>& array, int index) {
+ TFLITE_DCHECK(index >= 0 && index < N);
+ return array.sizes[index];
+}
+
+// Get common array size, DCHECKing that they all agree.
+template <typename ArrayType1, typename ArrayType2>
+int MatchingArraySize(const ArrayType1& array1, int index1,
+ const ArrayType2& array2, int index2) {
+ TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
+ return ArraySize(array1, index1);
+}
+
+template <typename ArrayType1, typename ArrayType2, typename... Args>
+int MatchingArraySize(const ArrayType1& array1, int index1,
+ const ArrayType2& array2, int index2, Args... args) {
+ TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
+ return MatchingArraySize(array1, index1, args...);
+}
+
+inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
+ int max_offset = 0;
+ for (int i = 0; i < 4; i++) {
+ max_offset += (dims.sizes[i] - 1) * dims.strides[i];
+ }
+ return max_offset + 1;
+}
+
+template <int N>
+bool IsPackedWithoutStrides(const Dims<N>& dims) {
+ int expected_stride = 1;
+ for (int d = 0; d < N; d++) {
+ if (dims.strides[d] != expected_stride) return false;
+ expected_stride *= dims.sizes[d];
+ }
+ return true;
+}
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.cc b/tensorflow/contrib/lite/kernels/kernel_util.cc
new file mode 100644
index 0000000000..b0546c00cf
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/kernel_util.cc
@@ -0,0 +1,87 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include <algorithm>
+#include <cmath>
+#include "tensorflow/contrib/lite/kernels/internal/round.h"
+
+namespace tflite {
+
+TfLiteStatus GetQuantizedConvolutionMultipler(
+ TfLiteContext* context, TfLiteTensor* input, TfLiteTensor* filter,
+ TfLiteTensor* bias, TfLiteTensor* output, double* multiplier) {
+ const double input_product_scale = input->params.scale * filter->params.scale;
+ const double bias_scale = bias->params.scale;
+ const double output_scale = output->params.scale;
+
+ // TODO(ahentz): The following conditions must be guaranteed by the training
+ // pipeline.
+ TF_LITE_ENSURE(context, std::abs(input_product_scale - bias_scale) <=
+ 1e-6 * std::min(input_product_scale, bias_scale));
+ TF_LITE_ENSURE(context, input_product_scale >= 0);
+ TF_LITE_ENSURE(context, input_product_scale < output_scale);
+
+ *multiplier = input_product_scale / output_scale;
+
+ return kTfLiteOk;
+}
+
+void CalculateActivationRangeUint8(TfLiteFusedActivation activation,
+ TfLiteTensor* output, int32_t* act_min,
+ int32_t* act_max) {
+ const int32_t qmin = std::numeric_limits<uint8_t>::min();
+ const int32_t qmax = std::numeric_limits<uint8_t>::max();
+
+ const auto scale = output->params.scale;
+ const auto zero_point = output->params.zero_point;
+
+ auto quantize = [scale, zero_point](float f) {
+ return zero_point + static_cast<int32_t>(TfLiteRound(f / scale));
+ };
+
+ if (activation == kTfLiteActRelu) {
+ *act_min = std::max(qmin, quantize(0.0));
+ *act_max = qmax;
+ } else if (activation == kTfLiteActRelu6) {
+ *act_min = std::max(qmin, quantize(0.0));
+ *act_max = std::min(qmax, quantize(6.0));
+ } else if (activation == kTfLiteActRelu1) {
+ *act_min = std::max(qmin, quantize(-1.0));
+ *act_max = std::min(qmax, quantize(1.0));
+ } else {
+ *act_min = qmin;
+ *act_max = qmax;
+ }
+}
+
+void CalculateActivationRangeFloat(TfLiteFusedActivation activation,
+ float* activation_min,
+ float* activation_max) {
+ if (activation == kTfLiteActRelu) {
+ *activation_min = 0.f;
+ *activation_max = std::numeric_limits<float>::max();
+ } else if (activation == kTfLiteActRelu6) {
+ *activation_min = 0.f;
+ *activation_max = 6.f;
+ } else if (activation == kTfLiteActRelu1) {
+ *activation_min = -1.f;
+ *activation_max = 1.f;
+ } else {
+ *activation_min = std::numeric_limits<float>::lowest();
+ *activation_max = std::numeric_limits<float>::max();
+ }
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h
new file mode 100644
index 0000000000..25556ae456
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/kernel_util.h
@@ -0,0 +1,65 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+
+namespace tflite {
+
+inline int NumDimensions(const TfLiteTensor* t) { return t->dims->size; }
+inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
+ return t->dims->data[dim];
+}
+inline TfLiteTensor* GetInput(TfLiteContext* context, TfLiteNode* node,
+ int index) {
+ return &context->tensors[node->inputs->data[index]];
+}
+inline TfLiteTensor* GetOutput(TfLiteContext* context, TfLiteNode* node,
+ int index) {
+ return &context->tensors[node->outputs->data[index]];
+}
+inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; }
+inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; }
+
+inline TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context,
+ const TfLiteNode* node, int index) {
+ const bool use_tensor = node->inputs->data[index] != kOptionalTensor;
+ if (use_tensor) {
+ return &context->tensors[node->inputs->data[index]];
+ }
+ return nullptr;
+}
+
+// Calculates the multiplication factor for a quantized convolution (or
+// quantized depthwise convolution) involving the given tensors. Returns an
+// error if the scales of the tensors are not compatible.
+TfLiteStatus GetQuantizedConvolutionMultipler(
+ TfLiteContext* context, TfLiteTensor* input, TfLiteTensor* filter,
+ TfLiteTensor* bias, TfLiteTensor* output, double* multiplier);
+
+// Calculates the useful range of an activation layer given its activation
+// tensor.
+void CalculateActivationRangeUint8(TfLiteFusedActivation activation,
+ TfLiteTensor* output, int32_t* act_min,
+ int32_t* act_max);
+void CalculateActivationRangeFloat(TfLiteFusedActivation activation,
+ float* activation_min,
+ float* activation_max);
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_
diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc
new file mode 100644
index 0000000000..f43aa372b6
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/l2norm.cc
@@ -0,0 +1,112 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace l2norm {
+
+// This file has two implementation of L2Norm.
+enum KernelType {
+ kReference,
+ kGenericOptimized,
+};
+
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteL2NormParams*>(node->builtin_data);
+
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ // TODO(ahentz): Our current implementations rely on the inputs being 4D.
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
+
+ // TODO(ahentz): Our current implementations only support float32.
+ TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+ // TODO(ahentz): For some reason our implementations don't support
+ // activations.
+ TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
+
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
+ output_size->data[0] = input->dims->data[0];
+ output_size->data[1] = input->dims->data[1];
+ output_size->data[2] = input->dims->data[2];
+ output_size->data[3] = input->dims->data[3];
+
+ return context->ResizeTensor(context, output, output_size);
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ if (output->type == kTfLiteFloat32) {
+#define TF_LITE_L2NORM(type) \
+ type::L2Normalization<FusedActivationFunctionType::kNone>( \
+ GetTensorData<float>(input), GetTensorDims(input), \
+ GetTensorData<float>(output), GetTensorDims(output))
+
+ if (kernel_type == kReference) {
+ TF_LITE_L2NORM(reference_ops);
+ }
+ if (kernel_type == kGenericOptimized) {
+ TF_LITE_L2NORM(optimized_ops);
+ }
+#undef TF_LITE_L2NORM
+ } else {
+ context->ReportError(context, "Inputs and outputs not all float types.");
+ return kTfLiteError;
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace l2norm
+
+TfLiteRegistration* Register_L2NORM_REF() {
+ static TfLiteRegistration r = {nullptr, nullptr, l2norm::Prepare,
+ l2norm::Eval<l2norm::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_L2NORM_GENERIC_OPT() {
+ static TfLiteRegistration r = {nullptr, nullptr, l2norm::Prepare,
+ l2norm::Eval<l2norm::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_L2_NORMALIZATION() {
+ return Register_L2NORM_GENERIC_OPT();
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/l2norm_test.cc b/tensorflow/contrib/lite/kernels/l2norm_test.cc
new file mode 100644
index 0000000000..b1db89b8bd
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/l2norm_test.cc
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class L2NormOpModel : public SingleOpModel {
+ public:
+ L2NormOpModel(std::initializer_list<int> input_shape,
+ ActivationFunctionType activation_type) {
+ input_ = AddInput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions,
+ CreateL2NormOptions(builder_, activation_type).Union());
+ BuildInterpreter({input_shape});
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ private:
+ int input_;
+ int output_;
+};
+
+TEST(L2NormOpTest, SimpleTest) {
+ L2NormOpModel m({1, 1, 1, 6}, ActivationFunctionType_NONE);
+ m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/local_response_norm.cc b/tensorflow/contrib/lite/kernels/local_response_norm.cc
new file mode 100644
index 0000000000..c1c70d0dfa
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/local_response_norm.cc
@@ -0,0 +1,109 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace local_response_norm {
+
+// This file has two implementation of LocalResponseNorm.
+enum KernelType {
+ kReference,
+ kGenericOptimized,
+};
+
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
+
+ TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
+ output_size->data[0] = input->dims->data[0];
+ output_size->data[1] = input->dims->data[1];
+ output_size->data[2] = input->dims->data[2];
+ output_size->data[3] = input->dims->data[3];
+
+ return context->ResizeTensor(context, output, output_size);
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteLocalResponseNormParams*>(node->builtin_data);
+
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ if (output->type == kTfLiteFloat32) {
+#define TF_LITE_LOCAL_RESPONSE_NORM(type) \
+ type::LocalResponseNormalization( \
+ GetTensorData<float>(input), GetTensorDims(input), params->radius, \
+ params->bias, params->alpha, params->beta, GetTensorData<float>(output), \
+ GetTensorDims(output))
+ if (kernel_type == kReference) {
+ TF_LITE_LOCAL_RESPONSE_NORM(reference_ops);
+ }
+ if (kernel_type == kGenericOptimized) {
+ TF_LITE_LOCAL_RESPONSE_NORM(optimized_ops);
+ }
+#undef TF_LITE_LOCAL_RESPONSE_NORM
+ } else {
+ context->ReportError(context, "Inputs and outputs not all float types.");
+ return kTfLiteError;
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace local_response_norm
+
+TfLiteRegistration* Register_LOCAL_RESPONSE_NORM_REF() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, local_response_norm::Prepare,
+ local_response_norm::Eval<local_response_norm::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_LOCAL_RESPONSE_NORM_GENERIC_OPT() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, local_response_norm::Prepare,
+ local_response_norm::Eval<local_response_norm::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_LOCAL_RESPONSE_NORMALIZATION() {
+ return Register_LOCAL_RESPONSE_NORM_GENERIC_OPT();
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/local_response_norm_test.cc b/tensorflow/contrib/lite/kernels/local_response_norm_test.cc
new file mode 100644
index 0000000000..63a8b0a3d0
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/local_response_norm_test.cc
@@ -0,0 +1,101 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class LocalResponseNormOpModel : public SingleOpModel {
+ public:
+ LocalResponseNormOpModel(std::initializer_list<int> input_shape, int radius,
+ float bias, float alpha, float beta) {
+ input_ = AddInput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
+ BuiltinOptions_LocalResponseNormalizationOptions,
+ CreateLocalResponseNormalizationOptions(builder_, radius, bias,
+ alpha, beta)
+ .Union());
+ BuildInterpreter({input_shape});
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ private:
+ int input_;
+ int output_;
+};
+
+TEST(LocalResponseNormOpTest, SameAsL2Norm) {
+ LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0,
+ /*alpha=*/1.0, /*beta=*/0.5);
+ m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
+ m.Invoke();
+ // The result is every input divided by 2.
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05})));
+}
+
+TEST(LocalResponseNormOpTest, WithAlpha) {
+ LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0,
+ /*alpha=*/4.0, /*beta=*/0.5);
+ m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
+ m.Invoke();
+ // The result is every input divided by 3.
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
+ {-0.275, 0.15, 0.175, 0.3, -0.175, 0.025})));
+}
+
+TEST(LocalResponseNormOpTest, WithBias) {
+ LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/9.0,
+ /*alpha=*/4.0, /*beta=*/0.5);
+ m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
+ m.Invoke();
+ // The result is every input divided by 5.
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({-0.22, 0.12, 0.14, 0.24, -0.14, 0.02})));
+}
+
+TEST(LocalResponseNormOpTest, SmallRadius) {
+ LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/2, /*bias=*/9.0,
+ /*alpha=*/4.0, /*beta=*/0.5);
+ m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {-0.264926, 0.125109, 0.140112, 0.267261, -0.161788, 0.0244266})));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/lsh_projection.cc b/tensorflow/contrib/lite/kernels/lsh_projection.cc
new file mode 100644
index 0000000000..5f73b56ed9
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/lsh_projection.cc
@@ -0,0 +1,204 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// LSH Projection projects an input to a bit vector via locality senstive
+// hashing.
+//
+// Options:
+// Sparse:
+// Computed bit vector is considered to be sparse.
+// Each output element is an int32 made up by multiple bits computed from
+// hash functions.
+//
+// Dense:
+// Computed bit vector is considered to be dense. Each output element is
+// either 0 or 1 that represents a bit.
+//
+// Input:
+// Tensor[0]: Hash functions. Dim.size == 2, DataType: Float.
+// Tensor[0].Dim[0]: Num of hash functions.
+// Tensor[0].Dim[1]: Num of projected output bits generated by
+// each hash function.
+// In sparse case, Tensor[0].Dim[1] + ceil( log2(Tensor[0].Dim[0] )) <= 32.
+//
+// Tensor[1]: Input. Dim.size >= 1, No restriction on DataType.
+// Tensor[2]: Optional, Weight. Dim.size == 1, DataType: Float.
+// If not set, each element of input is considered to have same
+// weight of 1.0 Tensor[1].Dim[0] == Tensor[2].Dim[0]
+//
+// Output:
+// Sparse:
+// Output.Dim == { Tensor[0].Dim[0] }
+// A tensor of int32 that represents hash signatures,
+//
+// NOTE: To avoid collisions across hash functions, an offset value of
+// k * (1 << Tensor[0].Dim[1]) will be added to each signature,
+// k is the index of the hash function.
+// Dense:
+// Output.Dim == { Tensor[0].Dim[0] * Tensor[0].Dim[1] }
+// A flattened tensor represents projected bit vectors.
+
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <iostream>
+#include <limits>
+#include <memory>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+#include <farmhash.h>
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace lsh_projection {
+
+TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteLSHProjectionParams*>(node->builtin_data);
+ TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TfLiteTensor* hash = GetInput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(hash), 2);
+ // Support up to 32 bits.
+ TF_LITE_ENSURE(context, SizeOfDimension(hash, 1) <= 32);
+
+ TfLiteTensor* input = GetInput(context, node, 1);
+ TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
+
+ if (NumInputs(node) == 3) {
+ TfLiteTensor* weight = GetInput(context, node, 2);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(weight), 1);
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(weight, 0),
+ SizeOfDimension(input, 0));
+ }
+
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1);
+ switch (params->type) {
+ case kTfLiteLshProjectionSparse:
+ outputSize->data[0] = SizeOfDimension(hash, 0);
+ break;
+ case kTfLiteLshProjectionDense:
+ outputSize->data[0] = SizeOfDimension(hash, 0) * SizeOfDimension(hash, 1);
+ break;
+ default:
+ return kTfLiteError;
+ }
+ return context->ResizeTensor(context, output, outputSize);
+}
+
+// Compute sign bit of dot product of hash(seed, input) and weight.
+// NOTE: use float as seed, and convert it to double as a temporary solution
+// to match the trained model. This is going to be changed once the new
+// model is trained in an optimized method.
+//
+int RunningSignBit(const TfLiteTensor* input, const TfLiteTensor* weight,
+ float seed) {
+ double score = 0.0;
+ int input_item_bytes = input->bytes / SizeOfDimension(input, 0);
+ char* input_ptr = input->data.raw;
+
+ const size_t seed_size = sizeof(float);
+ const size_t key_bytes = sizeof(float) + input_item_bytes;
+ std::unique_ptr<char[]> key(new char[key_bytes]);
+
+ for (int i = 0; i < SizeOfDimension(input, 0); ++i) {
+ // Create running hash id and value for current dimension.
+ memcpy(key.get(), &seed, seed_size);
+ memcpy(key.get() + seed_size, input_ptr, input_item_bytes);
+
+ int64_t hash_signature = ::util::Fingerprint64(key.get(), key_bytes);
+ double running_value = static_cast<double>(hash_signature);
+ input_ptr += input_item_bytes;
+ if (weight == nullptr) {
+ score += running_value;
+ } else {
+ score += weight->data.f[i] * running_value;
+ }
+ }
+
+ return (score > 0) ? 1 : 0;
+}
+
+void SparseLshProjection(const TfLiteTensor* hash, const TfLiteTensor* input,
+ const TfLiteTensor* weight, int32_t* out_buf) {
+ int num_hash = SizeOfDimension(hash, 0);
+ int num_bits = SizeOfDimension(hash, 1);
+ for (int i = 0; i < num_hash; i++) {
+ int32_t hash_signature = 0;
+ for (int j = 0; j < num_bits; j++) {
+ float seed = hash->data.f[i * num_bits + j];
+ int bit = RunningSignBit(input, weight, seed);
+ hash_signature = (hash_signature << 1) | bit;
+ }
+ *out_buf++ = hash_signature + i * (1 << num_bits);
+ }
+}
+
+void DenseLshProjection(const TfLiteTensor* hash, const TfLiteTensor* input,
+ const TfLiteTensor* weight, int32_t* out_buf) {
+ int num_hash = SizeOfDimension(hash, 0);
+ int num_bits = SizeOfDimension(hash, 1);
+ for (int i = 0; i < num_hash; i++) {
+ for (int j = 0; j < num_bits; j++) {
+ float seed = hash->data.f[i * num_bits + j];
+ int bit = RunningSignBit(input, weight, seed);
+ *out_buf++ = bit;
+ }
+ }
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteLSHProjectionParams*>(node->builtin_data);
+
+ int32_t* out_buf = GetOutput(context, node, 0)->data.i32;
+ TfLiteTensor* hash = GetInput(context, node, 0);
+ TfLiteTensor* input = GetInput(context, node, 1);
+ TfLiteTensor* weight =
+ NumInputs(node) == 2 ? nullptr : GetInput(context, node, 2);
+
+ switch (params->type) {
+ case kTfLiteLshProjectionDense:
+ DenseLshProjection(hash, input, weight, out_buf);
+ break;
+ case kTfLiteLshProjectionSparse:
+ SparseLshProjection(hash, input, weight, out_buf);
+ break;
+ default:
+ return kTfLiteError;
+ }
+
+ return kTfLiteOk;
+}
+} // namespace lsh_projection
+
+TfLiteRegistration* Register_LSH_PROJECTION() {
+ static TfLiteRegistration r = {nullptr, nullptr, lsh_projection::Resize,
+ lsh_projection::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/lsh_projection_test.cc b/tensorflow/contrib/lite/kernels/lsh_projection_test.cc
new file mode 100644
index 0000000000..1011927848
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/lsh_projection_test.cc
@@ -0,0 +1,123 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+
+class LSHProjectionOpModel : public SingleOpModel {
+ public:
+ LSHProjectionOpModel(LSHProjectionType type,
+ std::initializer_list<int> hash_shape,
+ std::initializer_list<int> input_shape,
+ std::initializer_list<int> weight_shape) {
+ hash_ = AddInput(TensorType_FLOAT32);
+ input_ = AddInput(TensorType_INT32);
+ if (weight_shape.size() > 0) {
+ weight_ = AddInput(TensorType_FLOAT32);
+ }
+ output_ = AddOutput(TensorType_INT32);
+
+ SetBuiltinOp(BuiltinOperator_LSH_PROJECTION,
+ BuiltinOptions_LSHProjectionOptions,
+ CreateLSHProjectionOptions(builder_, type).Union());
+ if (weight_shape.size() > 0) {
+ BuildInterpreter({hash_shape, input_shape, weight_shape});
+ } else {
+ BuildInterpreter({hash_shape, input_shape});
+ }
+
+ output_size_ = 1;
+ for (int i : hash_shape) {
+ output_size_ *= i;
+ if (type == LSHProjectionType_SPARSE) {
+ break;
+ }
+ }
+ }
+ void SetInput(std::initializer_list<int> data) {
+ PopulateTensor(input_, data);
+ }
+
+ void SetHash(std::initializer_list<float> data) {
+ PopulateTensor(hash_, data);
+ }
+
+ void SetWeight(std::initializer_list<float> f) { PopulateTensor(weight_, f); }
+
+ std::vector<int> GetOutput() { return ExtractVector<int>(output_); }
+
+ private:
+ int input_;
+ int hash_;
+ int weight_;
+ int output_;
+
+ int output_size_;
+};
+
+TEST(LSHProjectionOpTest2, Dense1DInputs) {
+ LSHProjectionOpModel m(LSHProjectionType_DENSE, {3, 2}, {5}, {5});
+
+ m.SetInput({12345, 54321, 67890, 9876, -12345678});
+ m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321});
+ m.SetWeight({1.0, 1.0, 1.0, 1.0, 1.0});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAre(0, 0, 0, 1, 0, 0));
+}
+
+TEST(LSHProjectionOpTest2, Sparse1DInputs) {
+ LSHProjectionOpModel m(LSHProjectionType_SPARSE, {3, 2}, {5}, {});
+
+ m.SetInput({12345, 54321, 67890, 9876, -12345678});
+ m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 0, 4 + 1, 8 + 0));
+}
+
+TEST(LSHProjectionOpTest2, Sparse3DInputs) {
+ LSHProjectionOpModel m(LSHProjectionType_SPARSE, {3, 2}, {5, 2, 2}, {5});
+
+ m.SetInput({1234, 2345, 3456, 1234, 4567, 5678, 6789, 4567, 7891, 8912,
+ 9123, 7890, -987, -876, -765, -987, -543, -432, -321, -543});
+ m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321});
+ m.SetWeight({0.12, 0.34, 0.56, 0.67, 0.78});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 2, 4 + 1, 8 + 1));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
new file mode 100644
index 0000000000..6c06264d84
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -0,0 +1,515 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace lstm {
+
+// Input Tensors of size {n_batch, n_input}
+constexpr int kInputTensor = 0;
+
+// Input weight tensors of size: {n_cell, n_input}
+constexpr int kInputToInputWeightsTensor = 1; // Optional
+constexpr int kInputToForgetWeightsTensor = 2;
+constexpr int kInputToCellWeightsTensor = 3;
+constexpr int kInputToOutputWeightsTensor = 4;
+
+// Recurrent weight tensors of size {n_cell, n_output}
+constexpr int kRecurrentToInputWeightsTensor = 5; // Optional
+constexpr int kRecurrentToForgetWeightsTensor = 6;
+constexpr int kRecurrentToCellWeightsTensor = 7;
+constexpr int kRecurrentToOutputWeightsTensor = 8;
+
+// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
+constexpr int kCellToInputWeightsTensor = 9; // Optional
+constexpr int kCellToForgetWeightsTensor = 10; // Optional
+constexpr int kCellToOutputWeightsTensor = 11; // Optional
+
+// Gates bias tensors of size {n_cell}
+constexpr int kInputGateBiasTensor = 12; // Optional
+constexpr int kForgetGateBiasTensor = 13;
+constexpr int kCellGateBiasTensor = 14;
+constexpr int kOutputGateBiasTensor = 15;
+
+// Projection weight tensor of size {n_output, n_cell}
+constexpr int kProjectionWeightsTensor = 16; // Optional
+// Projection bias tensor of size {n_output}
+constexpr int kProjectionBiasTensor = 17; // Optional
+
+// Output tensors.
+constexpr int kScratchBufferTensor = 0;
+constexpr int kOutputStateTensor = 1;
+constexpr int kCellStateTensor = 2;
+constexpr int kOutputTensor = 3;
+
+// Check that input tensor dimensions matches with each other.
+TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
+ TfLiteNode* node, int n_input,
+ int n_output, int n_cell) {
+ auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+
+ // Making sure clipping parameters have valid values.
+ // == 0 means no clipping
+ // > 0 means clipping
+ TF_LITE_ENSURE(context, params->cell_clip >= 0);
+ TF_LITE_ENSURE(context, params->proj_clip >= 0);
+
+ TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ if (input_to_input_weights) {
+ TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
+ }
+
+ TfLiteTensor* input_to_forget_weights =
+ GetInput(context, node, kInputToForgetWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
+
+ TfLiteTensor* input_to_cell_weights =
+ GetInput(context, node, kInputToCellWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
+
+ TfLiteTensor* recurrent_to_input_weights =
+ GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
+ if (recurrent_to_input_weights) {
+ TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
+ n_cell);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
+ n_output);
+ }
+
+ TfLiteTensor* recurrent_to_forget_weights =
+ GetInput(context, node, kRecurrentToForgetWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
+ n_cell);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
+ n_output);
+
+ TfLiteTensor* recurrent_to_cell_weights =
+ GetInput(context, node, kRecurrentToCellWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
+ n_output);
+
+ // We make sure the input-gate's parameters are either both present (regular
+ // LSTM) or not at all (CIFG-LSTM).
+ const bool cifg_weights_all_or_none =
+ ((input_to_input_weights != nullptr) &&
+ (recurrent_to_input_weights != nullptr)) ||
+ ((input_to_input_weights == nullptr) &&
+ (recurrent_to_input_weights == nullptr));
+ TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);
+
+ TfLiteTensor* cell_to_input_weights =
+ GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
+ if (cell_to_input_weights) {
+ TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
+ }
+
+ TfLiteTensor* cell_to_forget_weights =
+ GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
+ if (cell_to_forget_weights) {
+ TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
+ }
+
+ TfLiteTensor* cell_to_output_weights =
+ GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+ if (cell_to_output_weights) {
+ TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
+ }
+
+ // Making sure the peephole weights are there all or none.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool peephole_weights_all_or_none =
+ ((cell_to_input_weights != nullptr || use_cifg) &&
+ (cell_to_forget_weights != nullptr) &&
+ (cell_to_output_weights != nullptr)) ||
+ ((cell_to_input_weights == nullptr) &&
+ (cell_to_forget_weights == nullptr) &&
+ (cell_to_output_weights == nullptr));
+ TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);
+
+ // Make sure the input gate bias is present only when not a CIFG-LSTM.
+ TfLiteTensor* input_gate_bias =
+ GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+ if (use_cifg) {
+ TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
+ } else {
+ TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
+ }
+
+ TfLiteTensor* forget_gate_bias =
+ GetInput(context, node, kForgetGateBiasTensor);
+ TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
+
+ TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+ TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);
+
+ TfLiteTensor* output_gate_bias =
+ GetInput(context, node, kOutputGateBiasTensor);
+ TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
+
+ TfLiteTensor* projection_weights =
+ GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+ if (projection_weights) {
+ TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
+ TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
+ }
+
+ TfLiteTensor* projection_bias =
+ GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+ if (projection_bias) {
+ TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
+ }
+
+ // Making sure the projection tensors are consistent:
+ // 1) If projection weight is not present, then projection bias should not be
+ // present.
+ // 2) If projection weight is present, then projection bias is optional.
+ // TODO(ghodrat): make sure this is correct.
+ const bool projecton_tensors_consistent =
+ ((projection_weights != nullptr) || (projection_bias == nullptr));
+ TF_LITE_ENSURE(context, projecton_tensors_consistent == true);
+
+ return kTfLiteOk;
+}
+
+// Resize the output, state and scratch tensors based on the sizes of the input
+// tensors. Also check that the size of the input tensors match each other.
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ // Check we have all the inputs and outputs we need.
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 4);
+
+ // Inferring batch size, number of outputs and number of cells from the
+ // input tensors.
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TF_LITE_ENSURE(context, input->dims->size > 1);
+ const int n_batch = input->dims->data[0];
+ const int n_input = input->dims->data[1];
+
+ TfLiteTensor* input_to_output_weights =
+ GetInput(context, node, kInputToOutputWeightsTensor);
+ const int n_cell = input_to_output_weights->dims->data[0];
+ TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
+
+ TfLiteTensor* recurrent_to_output_weights =
+ GetInput(context, node, kRecurrentToOutputWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
+ n_cell);
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Check that input tensor dimensions matches with each other.
+ CheckInputTensorDimensions(context, node, n_input, n_output, n_cell);
+
+ // Get the pointer to output, state and scratch buffer tensors.
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
+ TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
+ // TODO(ghodrat): Modify this as soon as we have a finalized method for
+ // scratch buffers.
+ TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
+
+ // Resize the output and output_state tensors.
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
+ output_size->data[0] = n_batch;
+ output_size->data[1] = n_output;
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, output, output_size));
+
+ TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2);
+ output_state_size->data[0] = n_batch;
+ output_state_size->data[1] = n_output;
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, output_state, output_state_size));
+
+ // Resize the output, state and scratch buffer tensors.
+ TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
+ cell_size->data[0] = n_batch;
+ cell_size->data[1] = n_cell;
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, cell_state, cell_size));
+
+ // Mark state tensors as persistent tensors.
+ output_state->allocation_type = kTfLiteArenaRwPersistent;
+ cell_state->allocation_type = kTfLiteArenaRwPersistent;
+
+ TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ if (use_cifg) {
+ TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
+ scratch_buffer_size->data[0] = n_batch;
+ // Reserving space for Cell, Forget, Output gates
+ scratch_buffer_size->data[1] = n_cell * 3;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
+ scratch_buffer_size));
+ } else {
+ TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
+ scratch_buffer_size->data[0] = n_batch;
+ // Reserving space for Input, Cell, Forget, Output gates
+ scratch_buffer_size->data[1] = n_cell * 4;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
+ scratch_buffer_size));
+ }
+ return kTfLiteOk;
+}
+
+// The LSTM Op engine.
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+
+ TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ TfLiteTensor* input_to_forget_weights =
+ GetInput(context, node, kInputToForgetWeightsTensor);
+ TfLiteTensor* input_to_cell_weights =
+ GetInput(context, node, kInputToCellWeightsTensor);
+ TfLiteTensor* input_to_output_weights =
+ GetInput(context, node, kInputToOutputWeightsTensor);
+
+ TfLiteTensor* recurrent_to_input_weights =
+ GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
+ TfLiteTensor* recurrent_to_forget_weights =
+ GetInput(context, node, kRecurrentToForgetWeightsTensor);
+ TfLiteTensor* recurrent_to_cell_weights =
+ GetInput(context, node, kRecurrentToCellWeightsTensor);
+ TfLiteTensor* recurrent_to_output_weights =
+ GetInput(context, node, kRecurrentToOutputWeightsTensor);
+
+ TfLiteTensor* cell_to_input_weights =
+ GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
+ TfLiteTensor* cell_to_forget_weights =
+ GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
+ TfLiteTensor* cell_to_output_weights =
+ GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+
+ TfLiteTensor* input_gate_bias =
+ GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+ TfLiteTensor* forget_gate_bias =
+ GetInput(context, node, kForgetGateBiasTensor);
+ TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+ TfLiteTensor* output_gate_bias =
+ GetInput(context, node, kOutputGateBiasTensor);
+
+ TfLiteTensor* projection_weights =
+ GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+ TfLiteTensor* projection_bias =
+ GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+
+ TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
+ TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ const int n_batch = input->dims->data[0];
+ const int n_input = input->dims->data[1];
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Since we have already checked that weights are all there or none, we can
+ // check the existense of only one to the get the condition.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+ // Index the scratch buffers pointers to the global scratch buffer.
+ TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
+
+ // Initialize scratch buffers with bias.
+ if (!use_cifg) {
+ tensor_utils::VectorBatchVectorAssign(input_gate_bias->data.f, n_cell,
+ n_batch, input_gate_scratch);
+ }
+ tensor_utils::VectorBatchVectorAssign(forget_gate_bias->data.f, n_cell,
+ n_batch, forget_gate_scratch);
+ tensor_utils::VectorBatchVectorAssign(cell_bias->data.f, n_cell, n_batch,
+ cell_scratch);
+ tensor_utils::VectorBatchVectorAssign(output_gate_bias->data.f, n_cell,
+ n_batch, output_gate_scratch);
+
+ // For each batch and cell: compute input_weight * input.
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_input_weights->data.f, n_cell, n_input, input->data.f, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_forget_weights->data.f, n_cell, n_input, input->data.f, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_cell_weights->data.f, n_cell, n_input, input->data.f, n_batch,
+ cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_output_weights->data.f, n_cell, n_input, input->data.f, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+
+ // For each batch and cell: compute recurrent_weight * output_state.
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_input_weights->data.f, n_cell, n_output,
+ output_state->data.f, n_batch, input_gate_scratch, /*result_stride=*/1);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_forget_weights->data.f, n_cell, n_output,
+ output_state->data.f, n_batch, forget_gate_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_cell_weights->data.f, n_cell, n_output, output_state->data.f,
+ n_batch, cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_output_weights->data.f, n_cell, n_output,
+ output_state->data.f, n_batch, output_gate_scratch, /*result_stride=*/1);
+
+ // For each batch and cell: update input gate.
+ if (!use_cifg) {
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_input_weights->data.f, n_cell, cell_state->data.f, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+ input_gate_scratch);
+ }
+
+ // For each batch and cell: update forget gate.
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_forget_weights->data.f, n_cell, cell_state->data.f, n_batch,
+ forget_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+ forget_gate_scratch);
+
+ // For each batch and cell: update the cell.
+ tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch,
+ cell_state->data.f, n_batch * n_cell,
+ cell_state->data.f);
+ tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+ params->activation, cell_scratch);
+ if (use_cifg) {
+ tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+ forget_gate_scratch);
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, forget_gate_scratch, n_batch * n_cell,
+ cell_state->data.f);
+ } else {
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state->data.f);
+ }
+ if (params->cell_clip > 0.0) {
+ tensor_utils::ClipVector(cell_state->data.f, n_batch * n_cell,
+ params->cell_clip, cell_state->data.f);
+ }
+
+ // For each batch and cell: update the output gate.
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_output_weights->data.f, n_cell, cell_state->data.f, n_batch,
+ output_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+ output_gate_scratch);
+ tensor_utils::ApplyActivationToVector(cell_state->data.f, n_batch * n_cell,
+ params->activation, cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ n_batch * n_cell, output_gate_scratch);
+
+ // For each batch: update the projection and output_state.
+ const bool use_projection_weight = (projection_weights != nullptr);
+ const bool use_projection_bias = (projection_bias != nullptr);
+ if (use_projection_weight) {
+ if (use_projection_bias) {
+ tensor_utils::VectorBatchVectorAssign(projection_bias->data.f, n_output,
+ n_batch, output->data.f);
+ } else {
+ tensor_utils::ZeroVector(output->data.f, n_batch * n_output);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ projection_weights->data.f, n_output, n_cell, output_gate_scratch,
+ n_batch, output->data.f, /*result_stride=*/1);
+ if (params->proj_clip > 0.0) {
+ tensor_utils::ClipVector(output->data.f, n_batch * n_output,
+ params->proj_clip, output->data.f);
+ }
+ } else {
+ tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+ output->data.f);
+ }
+ tensor_utils::CopyVector(output->data.f, n_batch * n_output,
+ output_state->data.f);
+
+ return kTfLiteOk;
+}
+
+} // namespace lstm
+
+TfLiteRegistration* Register_LSTM() {
+ static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+ lstm::Prepare, lstm::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc
new file mode 100644
index 0000000000..be4c7ddbf8
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/lstm_test.cc
@@ -0,0 +1,1088 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite LSTM op.
+
+#include <iomanip>
+#include <memory>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class LSTMOpModel : public SingleOpModel {
+ public:
+ LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg,
+ bool use_peephole, bool use_projection_weights,
+ bool use_projection_bias, float cell_clip, float proj_clip,
+ const std::vector<std::vector<int>>& input_shapes)
+ : n_batch_(n_batch),
+ n_input_(n_input),
+ n_cell_(n_cell),
+ n_output_(n_output) {
+ input_ = AddInput(TensorType_FLOAT32);
+
+ if (use_cifg) {
+ input_to_input_weights_ = AddNullInput();
+ } else {
+ input_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ }
+
+ input_to_forget_weights_ = AddInput(TensorType_FLOAT32);
+ input_to_cell_weights_ = AddInput(TensorType_FLOAT32);
+ input_to_output_weights_ = AddInput(TensorType_FLOAT32);
+
+ if (use_cifg) {
+ recurrent_to_input_weights_ = AddNullInput();
+ } else {
+ recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ }
+
+ recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32);
+ recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32);
+ recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32);
+
+ if (use_peephole) {
+ if (use_cifg) {
+ cell_to_input_weights_ = AddNullInput();
+ } else {
+ cell_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ }
+ cell_to_forget_weights_ = AddInput(TensorType_FLOAT32);
+ cell_to_output_weights_ = AddInput(TensorType_FLOAT32);
+ } else {
+ cell_to_input_weights_ = AddNullInput();
+ cell_to_forget_weights_ = AddNullInput();
+ cell_to_output_weights_ = AddNullInput();
+ }
+
+ if (use_cifg) {
+ input_gate_bias_ = AddNullInput();
+ } else {
+ input_gate_bias_ = AddInput(TensorType_FLOAT32);
+ }
+ forget_gate_bias_ = AddInput(TensorType_FLOAT32);
+ cell_bias_ = AddInput(TensorType_FLOAT32);
+ output_gate_bias_ = AddInput(TensorType_FLOAT32);
+
+ if (use_projection_weights) {
+ projection_weights_ = AddInput(TensorType_FLOAT32);
+ if (use_projection_bias) {
+ projection_bias_ = AddInput(TensorType_FLOAT32);
+ } else {
+ projection_bias_ = AddNullInput();
+ }
+ } else {
+ projection_weights_ = AddNullInput();
+ projection_bias_ = AddNullInput();
+ }
+
+ scratch_buffer_ = AddOutput(TensorType_FLOAT32);
+ // TODO(ghodrat): Modify these states when we have a permanent solution for
+ // persistent buffer.
+ output_state_ = AddOutput(TensorType_FLOAT32);
+ cell_state_ = AddOutput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+
+ SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
+ CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
+ cell_clip, proj_clip)
+ .Union());
+ BuildInterpreter(input_shapes);
+ }
+
+ void SetInputToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_input_weights_, f);
+ }
+
+ void SetInputToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_forget_weights_, f);
+ }
+
+ void SetInputToCellWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_cell_weights_, f);
+ }
+
+ void SetInputToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_output_weights_, f);
+ }
+
+ void SetRecurrentToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_input_weights_, f);
+ }
+
+ void SetRecurrentToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_forget_weights_, f);
+ }
+
+ void SetRecurrentToCellWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_cell_weights_, f);
+ }
+
+ void SetRecurrentToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_output_weights_, f);
+ }
+
+ void SetCellToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_input_weights_, f);
+ }
+
+ void SetCellToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_forget_weights_, f);
+ }
+
+ void SetCellToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_output_weights_, f);
+ }
+
+ void SetInputGateBias(std::initializer_list<float> f) {
+ PopulateTensor(input_gate_bias_, f);
+ }
+
+ void SetForgetGateBias(std::initializer_list<float> f) {
+ PopulateTensor(forget_gate_bias_, f);
+ }
+
+ void SetCellBias(std::initializer_list<float> f) {
+ PopulateTensor(cell_bias_, f);
+ }
+
+ void SetOutputGateBias(std::initializer_list<float> f) {
+ PopulateTensor(output_gate_bias_, f);
+ }
+
+ void SetProjectionWeights(std::initializer_list<float> f) {
+ PopulateTensor(projection_weights_, f);
+ }
+
+ void SetProjectionBias(std::initializer_list<float> f) {
+ PopulateTensor(projection_bias_, f);
+ }
+
+ void ResetOutputState() {
+ const int zero_buffer_size = n_cell_ * n_batch_;
+ std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
+ memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
+ PopulateTensor(output_state_, 0, zero_buffer.get(),
+ zero_buffer.get() + zero_buffer_size);
+ }
+
+ void ResetCellState() {
+ const int zero_buffer_size = n_cell_ * n_batch_;
+ std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
+ memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
+ PopulateTensor(cell_state_, 0, zero_buffer.get(),
+ zero_buffer.get() + zero_buffer_size);
+ }
+
+ void SetInput(int offset, float* begin, float* end) {
+ PopulateTensor(input_, offset, begin, end);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ int num_inputs() { return n_input_; }
+ int num_outputs() { return n_output_; }
+ int num_cells() { return n_cell_; }
+ int num_batches() { return n_batch_; }
+
+ private:
+ int input_;
+ int input_to_input_weights_;
+ int input_to_forget_weights_;
+ int input_to_cell_weights_;
+ int input_to_output_weights_;
+
+ int recurrent_to_input_weights_;
+ int recurrent_to_forget_weights_;
+ int recurrent_to_cell_weights_;
+ int recurrent_to_output_weights_;
+
+ int cell_to_input_weights_;
+ int cell_to_forget_weights_;
+ int cell_to_output_weights_;
+
+ int input_gate_bias_;
+ int forget_gate_bias_;
+ int cell_bias_;
+ int output_gate_bias_;
+
+ int projection_weights_;
+ int projection_bias_;
+
+ int output_;
+ int output_state_;
+ int cell_state_;
+ int scratch_buffer_;
+
+ int n_batch_;
+ int n_input_;
+ int n_cell_;
+ int n_output_;
+};
+
+TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+
+ LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/false, /*use_peephole=*/false,
+ /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
+ -0.34550029, 0.04266912, -0.15680569,
+ -0.34856534, 0.43890524});
+
+ lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
+ -0.20583314, 0.44344562, 0.22077113,
+ -0.29909778});
+
+ lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
+ -0.31343272, -0.40032279, 0.44781327,
+ 0.01387155, -0.35593212});
+
+ lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
+ 0.40525138, 0.44272184, 0.03897077, -0.1556896,
+ 0.19487578});
+
+ lstm.SetInputGateBias({0., 0., 0., 0.});
+
+ lstm.SetCellBias({0., 0., 0., 0.});
+
+ lstm.SetForgetGateBias({1., 1., 1., 1.});
+
+ lstm.SetOutputGateBias({0., 0., 0., 0.});
+
+ lstm.SetRecurrentToInputWeights(
+ {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
+ -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
+ -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
+
+ lstm.SetRecurrentToCellWeights(
+ {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
+ -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
+ -0.46367589, 0.26016325, -0.03894562, -0.16368064});
+
+ lstm.SetRecurrentToForgetWeights(
+ {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
+ -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
+ 0.28053468, 0.01560611, -0.20127171, -0.01140004});
+
+ lstm.SetRecurrentToOutputWeights(
+ {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
+ 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
+ -0.51818722, -0.15390486, 0.0468148, 0.39922136});
+
+ static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
+ static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126,
+ -0.15358765, -0.03716109, 0.12507336,
+ 0.41193449, -0.20860538, -0.15053082,
+ 0.09120187, 0.24278517, -0.12222792};
+
+ // Resetting cell_state and output_state
+ lstm.ResetCellState();
+ lstm.ResetOutputState();
+
+ const int input_sequence_size =
+ sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs());
+ for (int i = 0; i < input_sequence_size; i++) {
+ float* batch0_start = lstm_input + i * lstm.num_inputs();
+ float* batch0_end = batch0_start + lstm.num_inputs();
+
+ lstm.SetInput(0, batch0_start, batch0_end);
+
+ lstm.Invoke();
+
+ float* golden_start = lstm_golden_output + i * lstm.num_outputs();
+ float* golden_end = golden_start + lstm.num_outputs();
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+ EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ }
+}
+
+TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+
+ LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/true, /*use_peephole=*/true,
+ /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {0, 0}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {0, 0}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {0}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
+ 0.04717243, 0.48944736, -0.38535351,
+ -0.17212132});
+
+ lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988,
+ -0.3633365, -0.22755712, 0.28253698, 0.24407166,
+ 0.33826375});
+
+ lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593,
+ -0.09426838, -0.44257352, 0.54939759,
+ 0.01533556, 0.42751634});
+
+ lstm.SetCellBias({0., 0., 0., 0.});
+
+ lstm.SetForgetGateBias({1., 1., 1., 1.});
+
+ lstm.SetOutputGateBias({0., 0., 0., 0.});
+
+ lstm.SetRecurrentToCellWeights(
+ {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
+ 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
+ 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
+ 0.21193194});
+
+ lstm.SetRecurrentToForgetWeights(
+ {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
+ 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795,
+ -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349});
+
+ lstm.SetRecurrentToOutputWeights(
+ {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908,
+ -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835,
+ 0.50248802, 0.26114327, -0.43736315, 0.33149987});
+
+ lstm.SetCellToForgetWeights(
+ {0.47485286, -0.51955009, -0.24458408, 0.31544167});
+ lstm.SetCellToOutputWeights(
+ {-0.17135078, 0.82760304, 0.85573703, -0.77109635});
+
+ static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
+ static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585,
+ -0.05163646, -0.42312205, -0.01218222,
+ 0.24201041, -0.08124574, -0.358325,
+ -0.04621704, 0.21641694, -0.06471302};
+
+ // Resetting cell_state and output_state
+ lstm.ResetCellState();
+ lstm.ResetOutputState();
+
+ const int input_sequence_size =
+ sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs());
+ for (int i = 0; i < input_sequence_size; i++) {
+ float* batch0_start = lstm_input + i * lstm.num_inputs();
+ float* batch0_end = batch0_start + lstm.num_inputs();
+
+ lstm.SetInput(0, batch0_start, batch0_end);
+
+ lstm.Invoke();
+
+ float* golden_start = lstm_golden_output + i * lstm.num_outputs();
+ float* golden_end = golden_start + lstm.num_outputs();
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+ EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ }
+}
+
+TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
+ const int n_batch = 2;
+ const int n_input = 5;
+ const int n_cell = 20;
+ const int n_output = 16;
+
+ LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/false, /*use_peephole=*/true,
+ /*use_projection_weights=*/true,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {n_cell}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {n_output, n_cell}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ lstm.SetInputToInputWeights(
+ {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
+ 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048,
+ -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385,
+ -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282,
+ -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627,
+ -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226,
+ -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059,
+ 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698,
+ 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206,
+ 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585,
+ -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063,
+ 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603,
+ -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682,
+ -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988,
+ -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764,
+ 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476,
+ -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012,
+ -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604,
+ -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654,
+ -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677});
+
+ lstm.SetInputToForgetWeights(
+ {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236,
+ -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505,
+ -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495,
+ 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323,
+ 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421,
+ -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887,
+ -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791,
+ 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059,
+ 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068,
+ 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905,
+ 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605,
+ -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464,
+ 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506,
+ -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063,
+ -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375,
+ 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553,
+ 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353,
+ 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717,
+ -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371,
+ 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496});
+
+ lstm.SetInputToCellWeights(
+ {-0.04580283, -0.09549462, -0.032418985, -0.06454633,
+ -0.043528453, 0.043018587, -0.049152344, -0.12418144,
+ -0.078985475, -0.07596889, 0.019484362, -0.11434962,
+ -0.0074034138, -0.06314844, -0.092981495, 0.0062155537,
+ -0.025034338, -0.0028890965, 0.048929527, 0.06235075,
+ 0.10665918, -0.032036792, -0.08505916, -0.10843358,
+ -0.13002433, -0.036816437, -0.02130134, -0.016518239,
+ 0.0047691227, -0.0025825808, 0.066017866, 0.029991534,
+ -0.10652836, -0.1037554, -0.13056071, -0.03266643,
+ -0.033702414, -0.006473424, -0.04611692, 0.014419339,
+ -0.025174323, 0.0396852, 0.081777506, 0.06157468,
+ 0.10210095, -0.009658194, 0.046511717, 0.03603906,
+ 0.0069369148, 0.015960095, -0.06507666, 0.09551598,
+ 0.053568836, 0.06408714, 0.12835667, -0.008714329,
+ -0.20211966, -0.12093674, 0.029450472, 0.2849013,
+ -0.029227901, 0.1164364, -0.08560263, 0.09941786,
+ -0.036999565, -0.028842626, -0.0033637602, -0.017012902,
+ -0.09720865, -0.11193351, -0.029155117, -0.017936034,
+ -0.009768936, -0.04223324, -0.036159635, 0.06505112,
+ -0.021742892, -0.023377212, -0.07221364, -0.06430552,
+ 0.05453865, 0.091149814, 0.06387331, 0.007518393,
+ 0.055960953, 0.069779344, 0.046411168, 0.10509911,
+ 0.07463894, 0.0075130584, 0.012850982, 0.04555431,
+ 0.056955688, 0.06555285, 0.050801456, -0.009862683,
+ 0.00826772, -0.026555609, -0.0073611983, -0.0014897042});
+
+ lstm.SetInputToOutputWeights(
+ {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918,
+ -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534,
+ 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722,
+ -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761,
+ -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394,
+ 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154,
+ -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135,
+ -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564,
+ -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047,
+ -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304,
+ 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946,
+ 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646,
+ 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813,
+ -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403,
+ 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415,
+ 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495,
+ -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158,
+ 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295,
+ -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739,
+ -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956});
+
+ lstm.SetInputGateBias(
+ {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216,
+ -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339,
+ -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818,
+ 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196});
+
+ lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696,
+ 0.11098921, 0.15378423, 0.09263801, 0.09790885,
+ 0.09508917, 0.061199076, 0.07665568, -0.015443159,
+ -0.03499149, 0.046190713, 0.08895977, 0.10899629,
+ 0.40694186, 0.06030037, 0.012413437, -0.06108739});
+
+ lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873,
+ -0.1483596, -0.10639995, -0.091433935, 0.058573797,
+ -0.06809782, -0.07889636, -0.043246906, -0.09829136,
+ -0.4279842, 0.034901652, 0.18797937, 0.0075234566,
+ 0.016178843, 0.1749513, 0.13975595, 0.92058027});
+
+ lstm.SetOutputGateBias(
+ {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795,
+ 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895,
+ 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149,
+ -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877});
+
+ lstm.SetRecurrentToInputWeights(
+ {-0.001374326, -0.078856036, 0.10672688, 0.029162422,
+ -0.11585556, 0.02557986, -0.13446963, -0.035785314,
+ -0.01244275, 0.025961924, -0.02337298, -0.044228926,
+ -0.055839065, -0.046598054, -0.010546039, -0.06900766,
+ 0.027239809, 0.022582639, -0.013296484, -0.05459212,
+ 0.08981, -0.045407712, 0.08682226, -0.06867011,
+ -0.14390695, -0.02916037, 0.000996957, 0.091420636,
+ 0.14283475, -0.07390571, -0.06402044, 0.062524505,
+ -0.093129106, 0.04860203, -0.08364217, -0.08119002,
+ 0.009352075, 0.22920375, 0.0016303885, 0.11583097,
+ -0.13732095, 0.012405723, -0.07551853, 0.06343048,
+ 0.12162708, -0.031923793, -0.014335606, 0.01790974,
+ -0.10650317, -0.0724401, 0.08554849, -0.05727212,
+ 0.06556731, -0.042729504, -0.043227166, 0.011683251,
+ -0.013082158, -0.029302018, -0.010899579, -0.062036745,
+ -0.022509435, -0.00964907, -0.01567329, 0.04260106,
+ -0.07787477, -0.11576462, 0.017356863, 0.048673786,
+ -0.017577527, -0.05527947, -0.082487635, -0.040137455,
+ -0.10820036, -0.04666372, 0.022746278, -0.07851417,
+ 0.01068115, 0.032956902, 0.022433773, 0.0026891115,
+ 0.08944216, -0.0685835, 0.010513544, 0.07228705,
+ 0.02032331, -0.059686817, -0.0005566496, -0.086984694,
+ 0.040414046, -0.1380399, 0.094208956, -0.05722982,
+ 0.012092817, -0.04989123, -0.086576, -0.003399834,
+ -0.04696032, -0.045747425, 0.10091314, 0.048676282,
+ -0.029037097, 0.031399418, -0.0040285117, 0.047237843,
+ 0.09504992, 0.041799378, -0.049185462, -0.031518843,
+ -0.10516937, 0.026374253, 0.10058866, -0.0033195973,
+ -0.041975245, 0.0073591834, 0.0033782164, -0.004325073,
+ -0.10167381, 0.042500053, -0.01447153, 0.06464186,
+ -0.017142897, 0.03312627, 0.009205989, 0.024138335,
+ -0.011337001, 0.035530265, -0.010912711, 0.0706555,
+ -0.005894094, 0.051841937, -0.1401738, -0.02351249,
+ 0.0365468, 0.07590991, 0.08838724, 0.021681072,
+ -0.10086113, 0.019608743, -0.06195883, 0.077335775,
+ 0.023646897, -0.095322326, 0.02233014, 0.09756986,
+ -0.048691444, -0.009579111, 0.07595467, 0.11480546,
+ -0.09801813, 0.019894179, 0.08502348, 0.004032281,
+ 0.037211012, 0.068537936, -0.048005626, -0.091520436,
+ -0.028379958, -0.01556313, 0.06554592, -0.045599163,
+ -0.01672207, -0.020169014, -0.011877351, -0.20212261,
+ 0.010889619, 0.0047078193, 0.038385306, 0.08540671,
+ -0.017140968, -0.0035865551, 0.016678626, 0.005633034,
+ 0.015963363, 0.00871737, 0.060130805, 0.028611384,
+ 0.10109069, -0.015060172, -0.07894427, 0.06401885,
+ 0.011584063, -0.024466386, 0.0047652307, -0.09041358,
+ 0.030737216, -0.0046374933, 0.14215417, -0.11823516,
+ 0.019899689, 0.006106124, -0.027092824, 0.0786356,
+ 0.05052217, -0.058925, -0.011402121, -0.024987547,
+ -0.0013661642, -0.06832946, -0.015667673, -0.1083353,
+ -0.00096863037, -0.06988685, -0.053350925, -0.027275559,
+ -0.033664223, -0.07978348, -0.025200296, -0.017207067,
+ -0.058403496, -0.055697463, 0.005798788, 0.12965427,
+ -0.062582195, 0.0013350133, -0.10482091, 0.0379771,
+ 0.072521195, -0.0029455067, -0.13797039, -0.03628521,
+ 0.013806405, -0.017858358, -0.01008298, -0.07700066,
+ -0.017081132, 0.019358726, 0.0027079724, 0.004635139,
+ 0.062634714, -0.02338735, -0.039547626, -0.02050681,
+ 0.03385117, -0.083611414, 0.002862572, -0.09421313,
+ 0.058618143, -0.08598433, 0.00972939, 0.023867095,
+ -0.053934585, -0.023203006, 0.07452513, -0.048767887,
+ -0.07314807, -0.056307215, -0.10433547, -0.06440842,
+ 0.04328182, 0.04389765, -0.020006588, -0.09076438,
+ -0.11652589, -0.021705797, 0.03345259, -0.010329105,
+ -0.025767034, 0.013057034, -0.07316461, -0.10145612,
+ 0.06358255, 0.18531723, 0.07759293, 0.12006465,
+ 0.1305557, 0.058638252, -0.03393652, 0.09622831,
+ -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845,
+ -0.005644518, 0.06857898, -0.12598175, -0.035084512,
+ 0.03156317, -0.12794146, -0.031963028, 0.04692781,
+ 0.030070418, 0.0071660685, -0.095516115, -0.004643372,
+ 0.040170413, -0.062104587, -0.0037324072, 0.0554317,
+ 0.08184801, -0.019164372, 0.06791302, 0.034257166,
+ -0.10307039, 0.021943003, 0.046745934, 0.0790918,
+ -0.0265588, -0.007824208, 0.042546265, -0.00977924,
+ -0.0002440307, -0.017384544, -0.017990116, 0.12252321,
+ -0.014512694, -0.08251313, 0.08861942, 0.13589665,
+ 0.026351685, 0.012641483, 0.07466548, 0.044301085,
+ -0.045414884, -0.051112458, 0.03444247, -0.08502782,
+ -0.04106223, -0.028126027, 0.028473156, 0.10467447});
+
+ lstm.SetRecurrentToForgetWeights(
+ {-0.057784554, -0.026057621, -0.068447545, -0.022581743,
+ 0.14811787, 0.10826372, 0.09471067, 0.03987225,
+ -0.0039523416, 0.00030638507, 0.053185795, 0.10572994,
+ 0.08414449, -0.022036452, -0.00066928595, -0.09203576,
+ 0.032950465, -0.10985798, -0.023809856, 0.0021431844,
+ -0.02196096, -0.00326074, 0.00058621005, -0.074678116,
+ -0.06193199, 0.055729095, 0.03736828, 0.020123724,
+ 0.061878487, -0.04729229, 0.034919553, -0.07585433,
+ -0.04421272, -0.044019096, 0.085488975, 0.04058006,
+ -0.06890133, -0.030951202, -0.024628663, -0.07672815,
+ 0.034293607, 0.08556707, -0.05293577, -0.033561368,
+ -0.04899627, 0.0241671, 0.015736353, -0.095442444,
+ -0.029564252, 0.016493602, -0.035026584, 0.022337519,
+ -0.026871363, 0.004780428, 0.0077918363, -0.03601621,
+ 0.016435321, -0.03263031, -0.09543275, -0.047392778,
+ 0.013454138, 0.028934088, 0.01685226, -0.086110644,
+ -0.046250615, -0.01847454, 0.047608484, 0.07339695,
+ 0.034546845, -0.04881143, 0.009128804, -0.08802852,
+ 0.03761666, 0.008096139, -0.014454086, 0.014361001,
+ -0.023502491, -0.0011840804, -0.07607001, 0.001856849,
+ -0.06509276, -0.006021153, -0.08570962, -0.1451793,
+ 0.060212336, 0.055259194, 0.06974018, 0.049454916,
+ -0.027794661, -0.08077226, -0.016179763, 0.1169753,
+ 0.17213494, -0.0056326236, -0.053934924, -0.0124349,
+ -0.11520337, 0.05409887, 0.088759385, 0.0019655675,
+ 0.0042065294, 0.03881498, 0.019844765, 0.041858196,
+ -0.05695512, 0.047233116, 0.038937137, -0.06542224,
+ 0.014429736, -0.09719407, 0.13908425, -0.05379757,
+ 0.012321099, 0.082840554, -0.029899208, 0.044217527,
+ 0.059855383, 0.07711018, -0.045319796, 0.0948846,
+ -0.011724666, -0.0033288454, -0.033542685, -0.04764985,
+ -0.13873616, 0.040668588, 0.034832682, -0.015319203,
+ -0.018715994, 0.046002675, 0.0599172, -0.043107376,
+ 0.0294216, -0.002314414, -0.022424703, 0.0030315618,
+ 0.0014641669, 0.0029166266, -0.11878115, 0.013738511,
+ 0.12375372, -0.0006038222, 0.029104086, 0.087442465,
+ 0.052958444, 0.07558703, 0.04817258, 0.044462286,
+ -0.015213451, -0.08783778, -0.0561384, -0.003008196,
+ 0.047060397, -0.002058388, 0.03429439, -0.018839769,
+ 0.024734668, 0.024614193, -0.042046934, 0.09597743,
+ -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786,
+ -0.02558259, -0.022822596, -0.023273505, -0.02464396,
+ -0.10991725, -0.006240552, 0.0074488563, 0.024044557,
+ 0.04383914, -0.046476185, 0.028658995, 0.060410924,
+ 0.050786525, 0.009452605, -0.0073054377, -0.024810238,
+ 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517,
+ 0.015898481, 0.021362653, -0.030262267, 0.016587038,
+ -0.011442813, 0.041154444, -0.007631438, -0.03423484,
+ -0.010977775, 0.036152758, 0.0066366293, 0.11915515,
+ 0.02318443, -0.041350313, 0.021485701, -0.10906167,
+ -0.028218046, -0.00954771, 0.020531068, -0.11995105,
+ -0.03672871, 0.024019798, 0.014255957, -0.05221243,
+ -0.00661567, -0.04630967, 0.033188973, 0.10107534,
+ -0.014027541, 0.030796422, -0.10270911, -0.035999842,
+ 0.15443139, 0.07684145, 0.036571592, -0.035900835,
+ -0.0034699554, 0.06209149, 0.015920248, -0.031122351,
+ -0.03858649, 0.01849943, 0.13872518, 0.01503974,
+ 0.069941424, -0.06948533, -0.0088794185, 0.061282158,
+ -0.047401894, 0.03100163, -0.041533746, -0.10430945,
+ 0.044574402, -0.01425562, -0.024290353, 0.034563623,
+ 0.05866852, 0.023947537, -0.09445152, 0.035450947,
+ 0.02247216, -0.0042998926, 0.061146557, -0.10250651,
+ 0.020881841, -0.06747029, 0.10062043, -0.0023941975,
+ 0.03532124, -0.016341697, 0.09685456, -0.016764693,
+ 0.051808182, 0.05875331, -0.04536488, 0.001626336,
+ -0.028892258, -0.01048663, -0.009793449, -0.017093895,
+ 0.010987891, 0.02357273, -0.00010856845, 0.0099760275,
+ -0.001845119, -0.03551521, 0.0018358806, 0.05763657,
+ -0.01769146, 0.040995963, 0.02235177, -0.060430344,
+ 0.11475477, -0.023854522, 0.10071741, 0.0686208,
+ -0.014250481, 0.034261297, 0.047418304, 0.08562733,
+ -0.030519066, 0.0060542435, 0.014653856, -0.038836084,
+ 0.04096551, 0.032249358, -0.08355519, -0.026823482,
+ 0.056386515, -0.010401743, -0.028396193, 0.08507674,
+ 0.014410365, 0.020995233, 0.17040324, 0.11511526,
+ 0.02459721, 0.0066619175, 0.025853224, -0.023133837,
+ -0.081302024, 0.017264642, -0.009585969, 0.09491168,
+ -0.051313367, 0.054532815, -0.014298593, 0.10657464,
+ 0.007076659, 0.10964551, 0.0409152, 0.008275321,
+ -0.07283536, 0.07937492, 0.04192024, -0.1075027});
+
+ lstm.SetRecurrentToCellWeights(
+ {-0.037322544, 0.018592842, 0.0056175636, -0.06253426,
+ 0.055647098, -0.05713207, -0.05626563, 0.005559383,
+ 0.03375411, -0.025757805, -0.088049285, 0.06017052,
+ -0.06570978, 0.007384076, 0.035123326, -0.07920549,
+ 0.053676967, 0.044480428, -0.07663568, 0.0071805613,
+ 0.08089997, 0.05143358, 0.038261272, 0.03339287,
+ -0.027673481, 0.044746667, 0.028349208, 0.020090483,
+ -0.019443132, -0.030755889, -0.0040000007, 0.04465846,
+ -0.021585021, 0.0031670958, 0.0053199246, -0.056117613,
+ -0.10893326, 0.076739706, -0.08509834, -0.027997585,
+ 0.037871376, 0.01449768, -0.09002357, -0.06111149,
+ -0.046195522, 0.0422062, -0.005683705, -0.1253618,
+ -0.012925729, -0.04890792, 0.06985068, 0.037654128,
+ 0.03398274, -0.004781977, 0.007032333, -0.031787455,
+ 0.010868644, -0.031489216, 0.09525667, 0.013939797,
+ 0.0058680447, 0.0167067, 0.02668468, -0.04797466,
+ -0.048885044, -0.12722108, 0.035304096, 0.06554885,
+ 0.00972396, -0.039238118, -0.05159735, -0.11329045,
+ 0.1613692, -0.03750952, 0.06529313, -0.071974665,
+ -0.11769596, 0.015524369, -0.0013754242, -0.12446318,
+ 0.02786344, -0.014179351, 0.005264273, 0.14376344,
+ 0.015983658, 0.03406988, -0.06939408, 0.040699873,
+ 0.02111075, 0.09669095, 0.041345075, -0.08316494,
+ -0.07684199, -0.045768797, 0.032298047, -0.041805092,
+ 0.0119405, 0.0061010392, 0.12652606, 0.0064572375,
+ -0.024950314, 0.11574242, 0.04508852, -0.04335324,
+ 0.06760663, -0.027437469, 0.07216407, 0.06977076,
+ -0.05438599, 0.034033038, -0.028602652, 0.05346137,
+ 0.043184172, -0.037189785, 0.10420091, 0.00882477,
+ -0.054019816, -0.074273005, -0.030617684, -0.0028467078,
+ 0.024302477, -0.0038869337, 0.005332455, 0.0013399826,
+ 0.04361412, -0.007001822, 0.09631092, -0.06702025,
+ -0.042049985, -0.035070654, -0.04103342, -0.10273396,
+ 0.0544271, 0.037184782, -0.13150354, -0.0058036847,
+ -0.008264958, 0.042035464, 0.05891794, 0.029673764,
+ 0.0063542654, 0.044788733, 0.054816857, 0.062257513,
+ -0.00093483756, 0.048938446, -0.004952862, -0.007730018,
+ -0.04043371, -0.017094059, 0.07229206, -0.023670016,
+ -0.052195564, -0.025616996, -0.01520939, 0.045104615,
+ -0.007376126, 0.003533447, 0.006570588, 0.056037236,
+ 0.12436656, 0.051817212, 0.028532185, -0.08686856,
+ 0.11868599, 0.07663395, -0.07323171, 0.03463402,
+ -0.050708205, -0.04458982, -0.11590894, 0.021273347,
+ 0.1251325, -0.15313013, -0.12224372, 0.17228661,
+ 0.023029093, 0.086124025, 0.006445803, -0.03496501,
+ 0.028332196, 0.04449512, -0.042436164, -0.026587414,
+ -0.006041347, -0.09292539, -0.05678812, 0.03897832,
+ 0.09465633, 0.008115513, -0.02171956, 0.08304309,
+ 0.071401566, 0.019622514, 0.032163795, -0.004167056,
+ 0.02295182, 0.030739572, 0.056506045, 0.004612461,
+ 0.06524936, 0.059999723, 0.046395954, -0.0045512207,
+ -0.1335546, -0.030136576, 0.11584653, -0.014678886,
+ 0.0020118146, -0.09688814, -0.0790206, 0.039770417,
+ -0.0329582, 0.07922767, 0.029322514, 0.026405897,
+ 0.04207835, -0.07073373, 0.063781224, 0.0859677,
+ -0.10925287, -0.07011058, 0.048005477, 0.03438226,
+ -0.09606514, -0.006669445, -0.043381985, 0.04240257,
+ -0.06955775, -0.06769346, 0.043903265, -0.026784198,
+ -0.017840602, 0.024307009, -0.040079936, -0.019946516,
+ 0.045318738, -0.12233574, 0.026170589, 0.0074471775,
+ 0.15978073, 0.10185836, 0.10298046, -0.015476589,
+ -0.039390966, -0.072174534, 0.0739445, -0.1211869,
+ -0.0347889, -0.07943156, 0.014809798, -0.12412325,
+ -0.0030663363, 0.039695457, 0.0647603, -0.08291318,
+ -0.018529687, -0.004423833, 0.0037507233, 0.084633216,
+ -0.01514876, -0.056505352, -0.012800942, -0.06994386,
+ 0.012962922, -0.031234352, 0.07029052, 0.016418684,
+ 0.03618972, 0.055686004, -0.08663945, -0.017404709,
+ -0.054761406, 0.029065743, 0.052404847, 0.020238016,
+ 0.0048197987, -0.0214882, 0.07078733, 0.013016777,
+ 0.06262858, 0.009184685, 0.020785125, -0.043904778,
+ -0.0270329, -0.03299152, -0.060088247, -0.015162964,
+ -0.001828936, 0.12642565, -0.056757294, 0.013586685,
+ 0.09232601, -0.035886683, 0.06000002, 0.05229691,
+ -0.052580316, -0.082029596, -0.010794592, 0.012947712,
+ -0.036429964, -0.085508935, -0.13127148, -0.017744139,
+ 0.031502828, 0.036232427, -0.031581745, 0.023051167,
+ -0.05325106, -0.03421577, 0.028793324, -0.034633752,
+ -0.009881397, -0.043551125, -0.018609839, 0.0019097115,
+ -0.008799762, 0.056595087, 0.0022273948, 0.055752404});
+
+ lstm.SetRecurrentToOutputWeights({
+ 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415,
+ -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349,
+ -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948,
+ -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774,
+ -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125,
+ -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224,
+ -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088,
+ 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867,
+ -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728,
+ 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607,
+ -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928,
+ -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462,
+ 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879,
+ 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698,
+ -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146,
+ 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345,
+ 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166,
+ 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203,
+ 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743,
+ 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415,
+ -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618,
+ 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891,
+ -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015,
+ 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109,
+ 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886,
+ 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396,
+ -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282,
+ -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025,
+ -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575,
+ -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277,
+ -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719,
+ -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215,
+ 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483,
+ 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102,
+ -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775,
+ 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841,
+ -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656,
+ -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286,
+ -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309,
+ 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545,
+ 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754,
+ 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831,
+ -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697,
+ 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453,
+ -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222,
+ -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989,
+ -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827,
+ -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949,
+ 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819,
+ -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954,
+ 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228,
+ -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001,
+ -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939,
+ -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556,
+ -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718,
+ 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893,
+ 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974,
+ -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485,
+ 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856,
+ 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853,
+ -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019,
+ 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024,
+ 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994,
+ 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621,
+ });
+
+ lstm.SetCellToInputWeights(
+ {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458,
+ -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174,
+ -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047,
+ 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175});
+
+ lstm.SetCellToForgetWeights(
+ {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276,
+ -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
+ -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774,
+ 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355});
+
+ lstm.SetCellToOutputWeights(
+ {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764,
+ -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544,
+ -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817,
+ 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733});
+
+ lstm.SetProjectionWeights(
+ {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832,
+ 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683,
+ -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931,
+ -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476,
+ 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067,
+ 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787,
+ 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588,
+ 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285,
+ -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949,
+ -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768,
+ -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929,
+ 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504,
+ 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946,
+ 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117,
+ 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253,
+ 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456,
+ -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552,
+ 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797,
+ -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272,
+ 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165,
+ -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922,
+ -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548,
+ 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786,
+ -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722,
+ 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318,
+ -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776,
+ -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307,
+ 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969,
+ -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593,
+ -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515,
+ -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288,
+ 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723,
+ 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097,
+ -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209,
+ 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268,
+ 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139,
+ 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707,
+ 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871,
+ 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553,
+ -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702,
+ -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615,
+ 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187,
+ -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388,
+ -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709,
+ 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263,
+ 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777,
+ 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935,
+ -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641,
+ -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996,
+ -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318,
+ 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437,
+ -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079,
+ 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237,
+ 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415,
+ -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124,
+ -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943,
+ -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311,
+ 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013,
+ -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364,
+ -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543,
+ -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102,
+ 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906,
+ 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955,
+ 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656});
+
+ static float lstm_input[][20] = {
+ {// Batch0: 4 (input_sequence_size) * 5 (n_input)
+ 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386,
+ 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199,
+ 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339},
+
+ {// Batch1: 4 (input_sequence_size) * 5 (n_input)
+ 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260,
+ 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485,
+ 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}};
+
+ static float lstm_golden_output[][64] = {
+ {// Batch0: 4 (input_sequence_size) * 16 (n_output)
+ -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576,
+ -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004,
+ -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147,
+ 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363,
+ -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322,
+ -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308,
+ 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794,
+ 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474,
+ 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827,
+ 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512,
+ -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407,
+ -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193,
+ 0.0286833, 0.00824207, 0.0264887, 0.0305169},
+ {// Batch1: 4 (input_sequence_size) * 16 (n_output)
+ -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926,
+ -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232,
+ 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954,
+ 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507,
+ -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039,
+ -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233,
+ 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378,
+ 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034,
+ 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789,
+ 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855,
+ -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679,
+ -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181,
+ 0.0412031, 0.0118723, 0.0239643, 0.0394009}};
+
+ // Resetting cell_state and output_state
+ lstm.ResetCellState();
+ lstm.ResetOutputState();
+
+ const int input_sequence_size =
+ sizeof(lstm_input[0]) / sizeof(float) / (lstm.num_inputs());
+ for (int i = 0; i < input_sequence_size; i++) {
+ float* batch0_start = lstm_input[0] + i * lstm.num_inputs();
+ float* batch0_end = batch0_start + lstm.num_inputs();
+
+ lstm.SetInput(0, batch0_start, batch0_end);
+
+ float* batch1_start = lstm_input[1] + i * lstm.num_inputs();
+ float* batch1_end = batch1_start + lstm.num_inputs();
+ lstm.SetInput(lstm.num_inputs(), batch1_start, batch1_end);
+
+ lstm.Invoke();
+
+ float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs();
+ float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs();
+ float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs();
+ float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs();
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start_batch0, golden_end_batch0);
+ expected.insert(expected.end(), golden_start_batch1, golden_end_batch1);
+ EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ }
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc
new file mode 100644
index 0000000000..81c73f2523
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/mul.cc
@@ -0,0 +1,167 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace mul {
+
+// This file has three implementation of Mul.
+enum KernelType {
+ kReference,
+ kGenericOptimized, // Neon-free
+ kNeonOptimized,
+};
+
+constexpr int kInputTensor1 = 0;
+constexpr int kInputTensor2 = 1;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2));
+ for (int i = 0; i < NumDimensions(input1); ++i) {
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i),
+ SizeOfDimension(input2, i));
+ }
+
+ TF_LITE_ENSURE_EQ(context, input1->type, output->type);
+ TF_LITE_ENSURE_EQ(context, input2->type, output->type);
+
+ TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims);
+ return context->ResizeTensor(context, output, output_size);
+}
+
+template <KernelType kernel_type>
+void EvalFloat(TfLiteContext* context, TfLiteNode* node,
+ TfLiteMulParams* params, TfLiteTensor* input1,
+ TfLiteTensor* input2, TfLiteTensor* output) {
+ float output_activation_min, output_activation_max;
+ CalculateActivationRangeFloat(params->activation, &output_activation_min,
+ &output_activation_max);
+#define TF_LITE_MUL(type) \
+ type::Mul(GetTensorData<float>(input1), GetTensorDims(input1), \
+ GetTensorData<float>(input2), GetTensorDims(input2), \
+ output_activation_min, output_activation_max, \
+ GetTensorData<float>(output), GetTensorDims(output))
+ if (kernel_type == kReference) {
+ TF_LITE_MUL(reference_ops);
+ } else {
+ TF_LITE_MUL(optimized_ops);
+ }
+#undef TF_LITE_MUL
+}
+
+template <KernelType kernel_type>
+void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteMulParams* params, TfLiteTensor* input1,
+ TfLiteTensor* input2, TfLiteTensor* output) {
+ auto input1_offset = -input1->params.zero_point;
+ auto input2_offset = -input2->params.zero_point;
+ auto output_offset = output->params.zero_point;
+
+ int32_t output_multiplier;
+ int output_shift;
+
+ double real_multiplier =
+ input1->params.scale * input2->params.scale / output->params.scale;
+ QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier,
+ &output_shift);
+
+ int32 output_activation_min, output_activation_max;
+ CalculateActivationRangeUint8(params->activation, output,
+ &output_activation_min, &output_activation_max);
+
+#define TF_LITE_MUL(type) \
+ type::BroadcastMul(GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
+ input1_offset, GetTensorData<uint8_t>(input2), \
+ GetTensorDims(input2), input2_offset, output_offset, \
+ output_multiplier, output_shift, output_activation_min, \
+ output_activation_max, GetTensorData<uint8_t>(output), \
+ GetTensorDims(output));
+ if (kernel_type == kReference) {
+ TF_LITE_MUL(reference_ops);
+ } else {
+ TF_LITE_MUL(optimized_ops);
+ }
+#undef TF_LITE_MUL
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data);
+
+ TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ if (output->type == kTfLiteFloat32) {
+ EvalFloat<kernel_type>(context, node, params, input1, input2, output);
+ } else if (output->type == kTfLiteUInt8) {
+ EvalQuantized<kernel_type>(context, node, params, input1, input2, output);
+ } else {
+ context->ReportError(context,
+ "Mul only supports FLOAT32 and quantized UINT8 now.");
+ return kTfLiteError;
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace mul
+
+TfLiteRegistration* Register_MUL_REF() {
+ static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare,
+ mul::Eval<mul::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_MUL_GENERIC_OPT() {
+ static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare,
+ mul::Eval<mul::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_MUL_NEON_OPT() {
+ static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare,
+ mul::Eval<mul::kNeonOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_MUL() {
+#ifdef USE_NEON
+ return Register_MUL_NEON_OPT();
+#else
+ return Register_MUL_GENERIC_OPT();
+#endif
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/mul_test.cc b/tensorflow/contrib/lite/kernels/mul_test.cc
new file mode 100644
index 0000000000..4b858e1f39
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/mul_test.cc
@@ -0,0 +1,127 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseMulOpModel : public SingleOpModel {
+ public:
+ BaseMulOpModel(TensorData input, TensorData output,
+ ActivationFunctionType activation_type) {
+ input1_ = AddInput(input);
+ input2_ = AddInput(input);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_MUL, BuiltinOptions_MulOptions,
+ CreateMulOptions(builder_, activation_type).Union());
+ BuildInterpreter({GetShape(input1_), GetShape(input2_)});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ protected:
+ int input1_;
+ int input2_;
+ int output_;
+};
+
+class FloatMulOpModel : public BaseMulOpModel {
+ public:
+ using BaseMulOpModel::BaseMulOpModel;
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+// For quantized Mul, the error shouldn't exceed (2*step + step^2).
+// The param min=-1.0 & max=1.0 is used in the following tests.
+// The tolerance value is ~0.0157.
+const float kQuantizedStep = 2.0 / 255.0;
+const float kQuantizedTolerance =
+ 2.0 * kQuantizedStep + kQuantizedStep * kQuantizedStep;
+
+class QuantizedMulOpModel : public BaseMulOpModel {
+ public:
+ using BaseMulOpModel::BaseMulOpModel;
+
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
+};
+
+TEST(FloatMulOpTest, NoActivation) {
+ FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
+ m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4})));
+}
+
+TEST(FloatMulOpTest, ActivationRELU1) {
+ FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU1);
+ m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
+ m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 1.0})));
+}
+
+TEST(FloatMulOpTest, VariousInputShapes) {
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ FloatMulOpModel m({TensorType_FLOAT32, test_shapes[i]},
+ {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0});
+ m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5, 1.1, 0.1});
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4, 1.21, 0.2})))
+ << "With shape number " << i;
+ }
+}
+
+TEST(QuantizedMulOpTest, NoActivation) {
+ QuantizedMulOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0},
+ {TensorType_UINT8, {}, -1.0, 1.0},
+ ActivationFunctionType_NONE);
+ m.QuantizeAndPopulate<uint8_t>(m.input1(), {-0.8, 0.2, 0.9, 0.7});
+ m.QuantizeAndPopulate<uint8_t>(m.input2(), {0.6, 0.4, 0.9, 0.8});
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56},
+ kQuantizedTolerance)));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/op_macros.h b/tensorflow/contrib/lite/kernels/op_macros.h
new file mode 100644
index 0000000000..7535afaf8e
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/op_macros.h
@@ -0,0 +1,32 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_
+
+#define TF_LITE_FATAL(msg) \
+ do { \
+ fprintf(stderr, "%s\n", (msg)); \
+ exit(1); \
+ } while (0)
+#define TF_LITE_ASSERT(x) \
+ do { \
+ if (!(x)) TF_LITE_FATAL(#x); \
+ } while (0)
+#define TF_LITE_ASSERT_EQ(x, y) \
+ do { \
+ if ((x) != (y)) TF_LITE_FATAL(#x " didn't equal " #y); \
+ } while (0)
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_
diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
new file mode 100644
index 0000000000..8977d27f73
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
@@ -0,0 +1,343 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite LSTM op.
+
+#include <iomanip>
+#include <memory>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class LSTMOpModel : public SingleOpModel {
+ public:
+ LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg,
+ bool use_peephole, bool use_projection_weights,
+ bool use_projection_bias, float cell_clip, float proj_clip,
+ const std::vector<std::vector<int>>& input_shapes)
+ : n_batch_(n_batch),
+ n_input_(n_input),
+ n_cell_(n_cell),
+ n_output_(n_output) {
+ input_ = AddInput(TensorType_FLOAT32);
+
+ if (use_cifg) {
+ input_to_input_weights_ = AddNullInput();
+ } else {
+ input_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ }
+
+ input_to_forget_weights_ = AddInput(TensorType_FLOAT32);
+ input_to_cell_weights_ = AddInput(TensorType_FLOAT32);
+ input_to_output_weights_ = AddInput(TensorType_FLOAT32);
+
+ if (use_cifg) {
+ recurrent_to_input_weights_ = AddNullInput();
+ } else {
+ recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ }
+
+ recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32);
+ recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32);
+ recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32);
+
+ if (use_peephole) {
+ if (use_cifg) {
+ cell_to_input_weights_ = AddNullInput();
+ } else {
+ cell_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ }
+ cell_to_forget_weights_ = AddInput(TensorType_FLOAT32);
+ cell_to_output_weights_ = AddInput(TensorType_FLOAT32);
+ } else {
+ cell_to_input_weights_ = AddNullInput();
+ cell_to_forget_weights_ = AddNullInput();
+ cell_to_output_weights_ = AddNullInput();
+ }
+
+ if (use_cifg) {
+ input_gate_bias_ = AddNullInput();
+ } else {
+ input_gate_bias_ = AddInput(TensorType_FLOAT32);
+ }
+ forget_gate_bias_ = AddInput(TensorType_FLOAT32);
+ cell_bias_ = AddInput(TensorType_FLOAT32);
+ output_gate_bias_ = AddInput(TensorType_FLOAT32);
+
+ if (use_projection_weights) {
+ projection_weights_ = AddInput(TensorType_FLOAT32);
+ if (use_projection_bias) {
+ projection_bias_ = AddInput(TensorType_FLOAT32);
+ } else {
+ projection_bias_ = AddNullInput();
+ }
+ } else {
+ projection_weights_ = AddNullInput();
+ projection_bias_ = AddNullInput();
+ }
+
+ scratch_buffer_ = AddOutput(TensorType_FLOAT32);
+ // TODO(ghodrat): Modify these states when we have a permanent solution for
+ // persistent buffer.
+ output_state_ = AddOutput(TensorType_FLOAT32);
+ cell_state_ = AddOutput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+
+ SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
+ CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
+ cell_clip, proj_clip)
+ .Union());
+ BuildInterpreter(input_shapes);
+ }
+
+ void SetInputToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_input_weights_, f);
+ }
+
+ void SetInputToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_forget_weights_, f);
+ }
+
+ void SetInputToCellWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_cell_weights_, f);
+ }
+
+ void SetInputToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_output_weights_, f);
+ }
+
+ void SetRecurrentToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_input_weights_, f);
+ }
+
+ void SetRecurrentToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_forget_weights_, f);
+ }
+
+ void SetRecurrentToCellWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_cell_weights_, f);
+ }
+
+ void SetRecurrentToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_output_weights_, f);
+ }
+
+ void SetCellToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_input_weights_, f);
+ }
+
+ void SetCellToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_forget_weights_, f);
+ }
+
+ void SetCellToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_output_weights_, f);
+ }
+
+ void SetInputGateBias(std::initializer_list<float> f) {
+ PopulateTensor(input_gate_bias_, f);
+ }
+
+ void SetForgetGateBias(std::initializer_list<float> f) {
+ PopulateTensor(forget_gate_bias_, f);
+ }
+
+ void SetCellBias(std::initializer_list<float> f) {
+ PopulateTensor(cell_bias_, f);
+ }
+
+ void SetOutputGateBias(std::initializer_list<float> f) {
+ PopulateTensor(output_gate_bias_, f);
+ }
+
+ void SetProjectionWeights(std::initializer_list<float> f) {
+ PopulateTensor(projection_weights_, f);
+ }
+
+ void SetProjectionBias(std::initializer_list<float> f) {
+ PopulateTensor(projection_bias_, f);
+ }
+
+ void ResetOutputState() {
+ const int zero_buffer_size = n_cell_ * n_batch_;
+ std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
+ memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
+ PopulateTensor(output_state_, 0, zero_buffer.get(),
+ zero_buffer.get() + zero_buffer_size);
+ }
+
+ void ResetCellState() {
+ const int zero_buffer_size = n_cell_ * n_batch_;
+ std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
+ memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
+ PopulateTensor(cell_state_, 0, zero_buffer.get(),
+ zero_buffer.get() + zero_buffer_size);
+ }
+
+ void SetInput(int offset, float* begin, float* end) {
+ PopulateTensor(input_, offset, begin, end);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ void Verify() {
+ auto model = tflite::UnPackModel(builder_.GetBufferPointer());
+ EXPECT_NE(model, nullptr);
+ }
+
+ int num_inputs() { return n_input_; }
+ int num_outputs() { return n_output_; }
+ int num_cells() { return n_cell_; }
+ int num_batches() { return n_batch_; }
+
+ private:
+ int input_;
+ int input_to_input_weights_;
+ int input_to_forget_weights_;
+ int input_to_cell_weights_;
+ int input_to_output_weights_;
+
+ int recurrent_to_input_weights_;
+ int recurrent_to_forget_weights_;
+ int recurrent_to_cell_weights_;
+ int recurrent_to_output_weights_;
+
+ int cell_to_input_weights_;
+ int cell_to_forget_weights_;
+ int cell_to_output_weights_;
+
+ int input_gate_bias_;
+ int forget_gate_bias_;
+ int cell_bias_;
+ int output_gate_bias_;
+
+ int projection_weights_;
+ int projection_bias_;
+
+ int output_;
+ int output_state_;
+ int cell_state_;
+ int scratch_buffer_;
+
+ int n_batch_;
+ int n_input_;
+ int n_cell_;
+ int n_output_;
+};
+
+
+TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+
+ LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/true, /*use_peephole=*/true,
+ /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {0, 0}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {0, 0}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {0}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+
+ lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
+ 0.04717243, 0.48944736, -0.38535351,
+ -0.17212132});
+
+ lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988,
+ -0.3633365, -0.22755712, 0.28253698, 0.24407166,
+ 0.33826375});
+
+ lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593,
+ -0.09426838, -0.44257352, 0.54939759,
+ 0.01533556, 0.42751634});
+
+ lstm.SetCellBias({0., 0., 0., 0.});
+
+ lstm.SetForgetGateBias({1., 1., 1., 1.});
+
+ lstm.SetOutputGateBias({0., 0., 0., 0.});
+
+ lstm.SetRecurrentToCellWeights(
+ {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
+ 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
+ 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
+ 0.21193194});
+
+ lstm.SetRecurrentToForgetWeights(
+ {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
+ 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795,
+ -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349});
+
+ lstm.SetRecurrentToOutputWeights(
+ {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908,
+ -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835,
+ 0.50248802, 0.26114327, -0.43736315, 0.33149987});
+
+ lstm.SetCellToForgetWeights(
+ {0.47485286, -0.51955009, -0.24458408, 0.31544167});
+ lstm.SetCellToOutputWeights(
+ {-0.17135078, 0.82760304, 0.85573703, -0.77109635});
+
+ // Resetting cell_state and output_state
+ lstm.ResetCellState();
+ lstm.ResetOutputState();
+
+ // Verify the model by unpacking it.
+ lstm.Verify();
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/padding.h b/tensorflow/contrib/lite/kernels/padding.h
new file mode 100644
index 0000000000..3a60274524
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/padding.h
@@ -0,0 +1,28 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
+
+namespace tflite {
+
+inline int ComputePadding(int stride, int in_size, int filter_size,
+ int out_size) {
+ int padding = ((out_size - 1) * stride + filter_size - in_size) / 2;
+ return padding > 0 ? padding : 0;
+}
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc
new file mode 100644
index 0000000000..b798801108
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/pooling.cc
@@ -0,0 +1,355 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+#include "tensorflow/contrib/lite/kernels/padding.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace pooling {
+
+// This file has two implementation of each pooling op.
+enum KernelType {
+ kReference,
+ kGenericOptimized,
+};
+
+enum PoolType {
+ kAverage,
+ kMax,
+ kL2,
+};
+
+struct OpData {
+ TfLitePaddingValues padding;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ // This is a builtin op, so we don't use the contents in 'buffer', if any.
+ // Instead, we allocate a new object to carry information from Prepare() to
+ // Eval().
+ return new OpData;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+template <PoolType pool_type>
+TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TfLiteTensor* input = GetInput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+ int batches = input->dims->data[0];
+ int height = input->dims->data[1];
+ int width = input->dims->data[2];
+ int channels_out = input->dims->data[3];
+
+ // Matching GetWindowedOutputSize in TensorFlow.
+ auto padding = params->padding;
+ auto computeOutSize = [padding](int imageSize, int filterSize,
+ int stride) -> int {
+ return padding == kTfLitePaddingSame
+ ? (imageSize + stride - 1) / stride
+ : padding == kTfLitePaddingValid
+ ? (imageSize - filterSize + stride) / stride
+ : 0;
+ };
+
+ int outWidth =
+ computeOutSize(width, params->filter_width, params->stride_width);
+ int outHeight =
+ computeOutSize(height, params->filter_height, params->stride_height);
+
+ data->padding.height = ComputePadding(params->stride_height, height,
+ params->filter_height, outHeight);
+ data->padding.width = ComputePadding(params->stride_width, width,
+ params->filter_width, outWidth);
+
+ if (input->type == kTfLiteUInt8) {
+ if (pool_type == kAverage || pool_type == kMax) {
+ TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale);
+ TF_LITE_ENSURE_EQ(context, input->params.zero_point,
+ output->params.zero_point);
+ }
+ if (pool_type == kL2) {
+ // We currently don't have a quantized implementation of L2Pool
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ }
+ }
+
+ TfLiteIntArray* outputSize = TfLiteIntArrayCreate(4);
+ outputSize->data[0] = batches;
+ outputSize->data[1] = outHeight;
+ outputSize->data[2] = outWidth;
+ outputSize->data[3] = channels_out;
+ return context->ResizeTensor(context, output, outputSize);
+}
+
+template <KernelType kernel_type>
+void AverageEvalFloat(TfLiteContext* context, TfLiteNode* node,
+ TfLitePoolParams* params, OpData* data,
+ TfLiteTensor* input, TfLiteTensor* output) {
+ float activation_min, activation_max;
+ CalculateActivationRangeFloat(params->activation, &activation_min,
+ &activation_max);
+#define TF_LITE_AVERAGE_POOL(type) \
+ type::AveragePool( \
+ GetTensorData<float>(input), GetTensorDims(input), params->stride_width, \
+ params->stride_height, data->padding.width, data->padding.height, \
+ params->filter_width, params->filter_height, activation_min, \
+ activation_max, GetTensorData<float>(output), GetTensorDims(output))
+ if (kernel_type == kReference) {
+ TF_LITE_AVERAGE_POOL(reference_ops);
+ } else {
+ TF_LITE_AVERAGE_POOL(optimized_ops);
+ }
+#undef TF_LITE_AVERAGE_POOL
+}
+
+template <KernelType kernel_type>
+void AverageEvalQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLitePoolParams* params, OpData* data,
+ TfLiteTensor* input, TfLiteTensor* output) {
+ int32_t activation_min;
+ int32_t activation_max;
+ CalculateActivationRangeUint8(params->activation, output, &activation_min,
+ &activation_max);
+#define TF_LITE_AVERAGE_POOL(type) \
+ type::AveragePool(GetTensorData<uint8_t>(input), GetTensorDims(input), \
+ params->stride_width, params->stride_height, \
+ data->padding.width, data->padding.height, \
+ params->filter_width, params->filter_height, \
+ activation_min, activation_max, \
+ GetTensorData<uint8_t>(output), GetTensorDims(output))
+ if (kernel_type == kReference) {
+ TF_LITE_AVERAGE_POOL(reference_ops);
+ } else {
+ TF_LITE_AVERAGE_POOL(optimized_ops);
+ }
+#undef TF_LITE_AVERAGE_POOL
+}
+
+template <KernelType kernel_type>
+void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node,
+ TfLitePoolParams* params, OpData* data, TfLiteTensor* input,
+ TfLiteTensor* output) {
+ float activation_min, activation_max;
+ CalculateActivationRangeFloat(params->activation, &activation_min,
+ &activation_max);
+#define TF_LITE_MAX_POOL(type) \
+ type::MaxPool( \
+ GetTensorData<float>(input), GetTensorDims(input), params->stride_width, \
+ params->stride_height, data->padding.width, data->padding.height, \
+ params->filter_width, params->filter_height, activation_min, \
+ activation_max, GetTensorData<float>(output), GetTensorDims(output))
+ if (kernel_type == kReference) {
+ TF_LITE_MAX_POOL(reference_ops);
+ } else {
+ TF_LITE_MAX_POOL(optimized_ops);
+ }
+#undef TF_LITE_MAX_POOL
+}
+
+template <KernelType kernel_type>
+void MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLitePoolParams* params, OpData* data,
+ TfLiteTensor* input, TfLiteTensor* output) {
+ int32_t activation_min;
+ int32_t activation_max;
+ CalculateActivationRangeUint8(params->activation, output, &activation_min,
+ &activation_max);
+#define TF_LITE_MAX_POOL(type) \
+ type::MaxPool(GetTensorData<uint8_t>(input), GetTensorDims(input), \
+ params->stride_width, params->stride_height, \
+ data->padding.width, data->padding.height, \
+ params->filter_width, params->filter_height, activation_min, \
+ activation_max, GetTensorData<uint8_t>(output), \
+ GetTensorDims(output))
+ if (kernel_type == kReference) {
+ TF_LITE_MAX_POOL(reference_ops);
+ } else {
+ TF_LITE_MAX_POOL(optimized_ops);
+ }
+#undef TF_LITE_MAX_POOL
+}
+
+template <KernelType kernel_type>
+void L2EvalFloat(TfLiteContext* context, TfLiteNode* node,
+ TfLitePoolParams* params, OpData* data, TfLiteTensor* input,
+ TfLiteTensor* output) {
+ float activation_min, activation_max;
+ CalculateActivationRangeFloat(params->activation, &activation_min,
+ &activation_max);
+#define TF_LITE_L2_POOL(type) \
+ type::L2Pool( \
+ GetTensorData<float>(input), GetTensorDims(input), params->stride_width, \
+ params->stride_height, data->padding.width, data->padding.height, \
+ params->filter_width, params->filter_height, activation_min, \
+ activation_max, GetTensorData<float>(output), GetTensorDims(output))
+ if (kernel_type == kReference) {
+ TF_LITE_L2_POOL(reference_ops);
+ } else {
+ TF_LITE_L2_POOL(optimized_ops);
+ }
+#undef TF_LITE_L2_POOL
+}
+
+#undef TF_LITE_KERNEL_TYPE_DISPATCH
+
+template <KernelType kernel_type>
+TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TfLiteTensor* input = GetInput(context, node, 0);
+ switch (input->type) { // Already know in/out types are same.
+ case kTfLiteFloat32:
+ AverageEvalFloat<kernel_type>(context, node, params, data, input, output);
+ break;
+ case kTfLiteUInt8:
+ AverageEvalQuantized<kernel_type>(context, node, params, data, input,
+ output);
+ break;
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+template <KernelType kernel_type>
+TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TfLiteTensor* input = GetInput(context, node, 0);
+ switch (input->type) { // Already know in/out types are same.
+ case kTfLiteFloat32:
+ MaxEvalFloat<kernel_type>(context, node, params, data, input, output);
+ break;
+ case kTfLiteUInt8:
+ MaxEvalQuantized<kernel_type>(context, node, params, data, input, output);
+ break;
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+template <KernelType kernel_type>
+TfLiteStatus L2Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TfLiteTensor* input = GetInput(context, node, 0);
+ switch (input->type) { // Already know in/out types are same.
+ case kTfLiteFloat32:
+ L2EvalFloat<kernel_type>(context, node, params, data, input, output);
+ break;
+ case kTfLiteUInt8:
+ // We don't have a quantized implementation, so just fall through to the
+ // 'default' case.
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace pooling
+
+TfLiteRegistration* Register_AVERAGE_POOL_REF() {
+ static TfLiteRegistration r = {pooling::Init, pooling::Free,
+ pooling::GenericPrepare<pooling::kAverage>,
+ pooling::AverageEval<pooling::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_MAX_POOL_REF() {
+ static TfLiteRegistration r = {pooling::Init, pooling::Free,
+ pooling::GenericPrepare<pooling::kMax>,
+ pooling::MaxEval<pooling::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_L2_POOL_REF() {
+ static TfLiteRegistration r = {pooling::Init, pooling::Free,
+ pooling::GenericPrepare<pooling::kL2>,
+ pooling::L2Eval<pooling::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_AVERAGE_POOL_GENERIC_OPT() {
+ static TfLiteRegistration r = {
+ pooling::Init, pooling::Free, pooling::GenericPrepare<pooling::kAverage>,
+ pooling::AverageEval<pooling::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_MAX_POOL_GENERIC_OPT() {
+ static TfLiteRegistration r = {pooling::Init, pooling::Free,
+ pooling::GenericPrepare<pooling::kMax>,
+ pooling::MaxEval<pooling::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_L2_POOL_GENERIC_OPT() {
+ static TfLiteRegistration r = {pooling::Init, pooling::Free,
+ pooling::GenericPrepare<pooling::kL2>,
+ pooling::L2Eval<pooling::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_AVERAGE_POOL_2D() {
+ return Register_AVERAGE_POOL_GENERIC_OPT();
+}
+
+TfLiteRegistration* Register_MAX_POOL_2D() {
+ return Register_MAX_POOL_GENERIC_OPT();
+}
+
+TfLiteRegistration* Register_L2_POOL_2D() {
+ return Register_L2_POOL_GENERIC_OPT();
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/pooling_test.cc b/tensorflow/contrib/lite/kernels/pooling_test.cc
new file mode 100644
index 0000000000..e1b51ec7d5
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/pooling_test.cc
@@ -0,0 +1,161 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdarg>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BasePoolingOpModel : public SingleOpModel {
+ public:
+ // TODO(ahentz): Also test different activation types, bias, padding types,
+ // stride values.
+ BasePoolingOpModel(BuiltinOperator type, const TensorData& input,
+ int filter_width, int filter_height,
+ const TensorData& output) {
+ input_ = AddInput(input);
+ output_ = AddOutput(output);
+
+ SetBuiltinOp(
+ type, BuiltinOptions_Pool2DOptions,
+ CreatePool2DOptions(builder_, Padding_VALID, 2, 2, filter_width,
+ filter_height, ActivationFunctionType_NONE)
+ .Union());
+
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ protected:
+ int input_;
+ int output_;
+};
+
+class FloatPoolingOpModel : public BasePoolingOpModel {
+ public:
+ using BasePoolingOpModel::BasePoolingOpModel;
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+class QuantizedPoolingOpModel : public BasePoolingOpModel {
+ public:
+ using BasePoolingOpModel::BasePoolingOpModel;
+
+ void SetInput(std::initializer_list<float> data) {
+ QuantizeAndPopulate<uint8_t>(input_, data);
+ }
+
+ std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
+};
+
+TEST(FloatPoolingOpTest, AveragePool) {
+ FloatPoolingOpModel m(BuiltinOperator_AVERAGE_POOL_2D,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
+ /*filter_width=*/2, /*filter_height=*/2,
+ /*output=*/{TensorType_FLOAT32, {}});
+ m.SetInput({
+ 0, 6, 2, 4, //
+ 3, 2, 10, 7, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({2.75, 5.75}));
+}
+
+TEST(QuantizedPoolingOpTest, AveragePool) {
+ // Choose the input ranges carefully so that the dequantized output matches
+ // the results of the float model above.
+ QuantizedPoolingOpModel m(
+ BuiltinOperator_AVERAGE_POOL_2D,
+ /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, 0, 15.9375},
+ /*filter_width=*/2, /*filter_height=*/2,
+ /*output=*/{TensorType_UINT8, {}, 0, 15.9375});
+ m.SetInput({
+ 0, 6, 2, 4, //
+ 3, 2, 10, 7, //
+ });
+ m.Invoke();
+
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({2.75, 5.75})));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({44, 92}));
+}
+
+TEST(FloatPoolingOpTest, MaxPool) {
+ FloatPoolingOpModel m(BuiltinOperator_MAX_POOL_2D,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
+ /*filter_width=*/2, /*filter_height=*/2,
+ /*output=*/{TensorType_FLOAT32, {}});
+ m.SetInput({
+ 0, 6, 2, 4, //
+ 3, 2, 10, 7, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 10}));
+}
+
+TEST(QuantizedPoolingOpTest, MaxPool) {
+ // Choose the input ranges carefully so that the dequantized output matches
+ // the results of the float model above.
+ QuantizedPoolingOpModel m(
+ BuiltinOperator_MAX_POOL_2D,
+ /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, 0, 15.9375},
+ /*filter_width=*/2, /*filter_height=*/2,
+ /*output=*/{TensorType_UINT8, {}, 0, 15.9375});
+ m.SetInput({
+ 0, 6, 2, 4, //
+ 3, 2, 10, 7, //
+ });
+ m.Invoke();
+
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({6, 10})));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({96, 160}));
+}
+
+TEST(FloatPoolingOpTest, L2Pool) {
+ FloatPoolingOpModel m(BuiltinOperator_L2_POOL_2D,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
+ /*filter_width=*/2, /*filter_height=*/2,
+ /*output=*/{TensorType_FLOAT32, {}});
+ m.SetInput({
+ 0, 6, 2, 4, //
+ 3, 2, 10, 7, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.5, 6.5}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
new file mode 100644
index 0000000000..ca7a0dd194
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -0,0 +1,109 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/kernels/register.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+
+TfLiteRegistration* Register_RELU();
+TfLiteRegistration* Register_RELU1();
+TfLiteRegistration* Register_RELU6();
+TfLiteRegistration* Register_TANH();
+TfLiteRegistration* Register_LOGISTIC();
+TfLiteRegistration* Register_AVERAGE_POOL_2D();
+TfLiteRegistration* Register_MAX_POOL_2D();
+TfLiteRegistration* Register_L2_POOL_2D();
+TfLiteRegistration* Register_CONV_2D();
+TfLiteRegistration* Register_DEPTHWISE_CONV_2D();
+TfLiteRegistration* Register_SVDF();
+TfLiteRegistration* Register_RNN();
+TfLiteRegistration* Register_EMBEDDING_LOOKUP();
+TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE();
+TfLiteRegistration* Register_FULLY_CONNECTED();
+TfLiteRegistration* Register_LSH_PROJECTION();
+TfLiteRegistration* Register_HASHTABLE_LOOKUP();
+TfLiteRegistration* Register_SOFTMAX();
+TfLiteRegistration* Register_CONCATENATION();
+TfLiteRegistration* Register_ADD();
+TfLiteRegistration* Register_MUL();
+TfLiteRegistration* Register_L2_NORMALIZATION();
+TfLiteRegistration* Register_LOCAL_RESPONSE_NORMALIZATION();
+TfLiteRegistration* Register_LSTM();
+TfLiteRegistration* Register_RESHAPE();
+TfLiteRegistration* Register_RESIZE_BILINEAR();
+TfLiteRegistration* Register_SKIP_GRAM();
+TfLiteRegistration* Register_SPACE_TO_DEPTH();
+
+BuiltinOpResolver::BuiltinOpResolver() {
+ AddBuiltin(BuiltinOperator_RELU, Register_RELU());
+ AddBuiltin(BuiltinOperator_RELU1, Register_RELU1());
+ AddBuiltin(BuiltinOperator_RELU6, Register_RELU6());
+ AddBuiltin(BuiltinOperator_TANH, Register_TANH());
+ AddBuiltin(BuiltinOperator_LOGISTIC, Register_LOGISTIC());
+ AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, Register_AVERAGE_POOL_2D());
+ AddBuiltin(BuiltinOperator_MAX_POOL_2D, Register_MAX_POOL_2D());
+ AddBuiltin(BuiltinOperator_L2_POOL_2D, Register_L2_POOL_2D());
+ AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D());
+ AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D());
+ AddBuiltin(BuiltinOperator_SVDF, Register_SVDF());
+ AddBuiltin(BuiltinOperator_RNN, Register_RNN());
+ AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, Register_EMBEDDING_LOOKUP());
+ AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE,
+ Register_EMBEDDING_LOOKUP_SPARSE());
+ AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED());
+ AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION());
+ AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP());
+ AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX());
+ AddBuiltin(BuiltinOperator_CONCATENATION, Register_CONCATENATION());
+ AddBuiltin(BuiltinOperator_ADD, Register_ADD());
+ AddBuiltin(BuiltinOperator_MUL, Register_MUL());
+ AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION());
+ AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
+ Register_LOCAL_RESPONSE_NORMALIZATION());
+ AddBuiltin(BuiltinOperator_LSTM, Register_LSTM());
+ AddBuiltin(BuiltinOperator_RESHAPE, Register_RESHAPE());
+ AddBuiltin(BuiltinOperator_RESIZE_BILINEAR, Register_RESIZE_BILINEAR());
+ AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM());
+ AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH());
+}
+
+TfLiteRegistration* BuiltinOpResolver::FindOp(
+ tflite::BuiltinOperator op) const {
+ auto it = builtins_.find(op);
+ return it != builtins_.end() ? it->second : nullptr;
+}
+
+TfLiteRegistration* BuiltinOpResolver::FindOp(const char* op) const {
+ auto it = custom_ops_.find(op);
+ return it != custom_ops_.end() ? it->second : nullptr;
+}
+
+void BuiltinOpResolver::AddBuiltin(tflite::BuiltinOperator op,
+ TfLiteRegistration* registration) {
+ registration->builtin_code = op;
+ builtins_.insert(std::make_pair(op, registration));
+}
+
+void BuiltinOpResolver::AddCustom(const char* name,
+ TfLiteRegistration* registration) {
+ registration->builtin_code = BuiltinOperator_CUSTOM;
+ custom_ops_.insert(std::make_pair(std::string(name), registration));
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h
new file mode 100644
index 0000000000..28f5e0fcc8
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/register.h
@@ -0,0 +1,50 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
+
+#include <unordered_map>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+
+class BuiltinOpResolver : public OpResolver {
+ public:
+ BuiltinOpResolver();
+ TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override;
+ TfLiteRegistration* FindOp(const char* op) const override;
+ void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration);
+ void AddCustom(const char* name, TfLiteRegistration* registration);
+
+ private:
+ struct BuiltinOperatorHasher {
+ size_t operator()(const tflite::BuiltinOperator& x) const {
+ return std::hash<size_t>()(static_cast<size_t>(x));
+ }
+ };
+ std::unordered_map<tflite::BuiltinOperator, TfLiteRegistration*,
+ BuiltinOperatorHasher>
+ builtins_;
+ std::unordered_map<std::string, TfLiteRegistration*> custom_ops_;
+};
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_BUILTIN_KERNELS_H
diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc
new file mode 100644
index 0000000000..f3e6ddc9f4
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/reshape.cc
@@ -0,0 +1,91 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace reshape {
+
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteReshapeParams*>(node->builtin_data);
+
+ // TODO(ahentz): we are often given a tensor with the shape but we only pay
+ // attention to what the shape specified in 'params'.
+ TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ // Tensorflow's Reshape allows one of the shape components to have the
+ // special -1 value, meaning it will be calculated automatically based on the
+ // input. Here we calculate what that dimension should be so that the number
+ // of output elements in the same as the number of input elements.
+ int num_input_elements = 1;
+ for (int i = 0; i < NumDimensions(input); ++i) {
+ num_input_elements *= SizeOfDimension(input, i);
+ }
+
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(params->num_dimensions);
+ int num_output_elements = 1;
+ int strech_dim = -1;
+ for (int i = 0; i < params->num_dimensions; ++i) {
+ int value = params->shape[i];
+ if (value == -1) {
+ TF_LITE_ENSURE_EQ(context, strech_dim, -1);
+ strech_dim = i;
+ } else {
+ num_output_elements *= value;
+ output_size->data[i] = value;
+ }
+ }
+ if (strech_dim != -1) {
+ output_size->data[strech_dim] = num_input_elements / num_output_elements;
+ num_output_elements *= output_size->data[strech_dim];
+ }
+
+ TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements);
+ return context->ResizeTensor(context, output, output_size);
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ memcpy(output->data.raw, input->data.raw, input->bytes);
+
+ return kTfLiteOk;
+}
+
+} // namespace reshape
+
+TfLiteRegistration* Register_RESHAPE() {
+ static TfLiteRegistration r = {nullptr, nullptr, reshape::Prepare,
+ reshape::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/reshape_test.cc b/tensorflow/contrib/lite/kernels/reshape_test.cc
new file mode 100644
index 0000000000..59ce7d5648
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/reshape_test.cc
@@ -0,0 +1,90 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class ReshapeOpModel : public SingleOpModel {
+ public:
+ ReshapeOpModel(std::initializer_list<int> input_shape,
+ std::initializer_list<int> new_shape) {
+ input_ = AddInput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(
+ BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions,
+ CreateReshapeOptions(builder_, builder_.CreateVector<int>(new_shape))
+ .Union());
+ BuildInterpreter({input_shape});
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor<float>(input_, data);
+ }
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input_;
+ int output_;
+};
+
+TEST(ReshapeOpTest, MismatchedDimensions) {
+ EXPECT_DEATH(ReshapeOpModel({1, 2, 4, 1}, {2, 1}),
+ "num_input_elements != num_output_elements");
+}
+
+TEST(ReshapeOpTest, TooManyDimensions) {
+ EXPECT_DEATH(
+ ReshapeOpModel({1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 2, 3, 4, 5, 6, 7, 8, 9}),
+ "Found too many dimensions");
+}
+
+TEST(ReshapeOpTest, TooManySpecialDimensions) {
+ EXPECT_DEATH(ReshapeOpModel({1, 2, 4, 1}, {-1, -1, 2, 4}),
+ "strech_dim != -1");
+}
+
+TEST(ReshapeOpTest, SimpleTest) {
+ ReshapeOpModel m({1, 2, 4, 1}, {2, 2, 2});
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2}));
+}
+
+TEST(ReshapeOpTest, WithStretchDimension) {
+ ReshapeOpModel m({1, 2, 4, 1}, {2, 1, -1});
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 4}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
new file mode 100644
index 0000000000..1613c9a89f
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
@@ -0,0 +1,129 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace resize_bilinear {
+
+// This file has three implementation of RESIZE_BILINEAR.
+enum KernelType {
+ kReference,
+ kGenericOptimized, // Neon-free
+ kNeonOptimized,
+};
+
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);
+
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ // TODO(ahentz): Our current implementations rely on the inputs being 4D.
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
+
+ // TODO(ahentz): Our current implementations only support float32.
+ TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
+ output_size->data[0] = input->dims->data[0];
+ output_size->data[1] = params->new_height;
+ output_size->data[2] = params->new_width;
+ output_size->data[3] = input->dims->data[3];
+
+ return context->ResizeTensor(context, output, output_size);
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);
+
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ // We have to fake a tensor here, to satisfy ResizeBilinear().
+ int32 output_size_data[2] = {params->new_height, params->new_width};
+
+ if (output->type == kTfLiteFloat32) {
+#define TF_LITE_RESIZE_BILINEAR(type) \
+ type::ResizeBilinear(GetTensorData<float>(input), GetTensorDims(input), \
+ output_size_data, GetTensorDims({1, 1, 1, 2}), \
+ GetTensorData<float>(output), GetTensorDims(output))
+
+ if (kernel_type == kReference) {
+ TF_LITE_RESIZE_BILINEAR(reference_ops);
+ }
+ if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) {
+ TF_LITE_RESIZE_BILINEAR(optimized_ops);
+ }
+#undef TF_LITE_RESIZE_BILINEAR
+ } else {
+ context->ReportError(context, "Inputs and outputs not all float types.");
+ return kTfLiteError;
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace resize_bilinear
+
+TfLiteRegistration* Register_RESIZE_BILINEAR_REF() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, resize_bilinear::Prepare,
+ resize_bilinear::Eval<resize_bilinear::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_RESIZE_BILINEAR_GENERIC_OPT() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, resize_bilinear::Prepare,
+ resize_bilinear::Eval<resize_bilinear::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_RESIZE_BILINEAR_NEON_OPT() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, resize_bilinear::Prepare,
+ resize_bilinear::Eval<resize_bilinear::kNeonOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_RESIZE_BILINEAR() {
+#ifdef USE_NEON
+ return Register_RESIZE_BILINEAR_NEON_OPT();
+#else
+ return Register_RESIZE_BILINEAR_GENERIC_OPT();
+#endif
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc
new file mode 100644
index 0000000000..0257c0b557
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc
@@ -0,0 +1,117 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class ResizeBilinearOpModel : public SingleOpModel {
+ public:
+ ResizeBilinearOpModel(std::initializer_list<int> input_shape, int new_height,
+ int new_width) {
+ input_ = AddInput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(
+ BuiltinOperator_RESIZE_BILINEAR, BuiltinOptions_ResizeBilinearOptions,
+ CreateResizeBilinearOptions(builder_, new_height, new_width).Union());
+ BuildInterpreter({input_shape});
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ private:
+ int input_;
+ int output_;
+};
+
+TEST(ResizeBilinearOpTest, HorizontalResize) {
+ ResizeBilinearOpModel m({1, 1, 2, 1}, 1, 3);
+ m.SetInput({3, 6});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6})));
+}
+
+TEST(ResizeBilinearOpTest, VerticalResize) {
+ ResizeBilinearOpModel m({1, 2, 1, 1}, 3, 1);
+ m.SetInput({3, 9});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9})));
+}
+
+TEST(ResizeBilinearOpTest, TwoDimensionalResize) {
+ ResizeBilinearOpModel m({1, 2, 2, 1}, 3, 3);
+ m.SetInput({
+ 3, 6, //
+ 9, 12 //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 3, 5, 6, //
+ 7, 9, 10, //
+ 9, 11, 12, //
+ })));
+}
+
+TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
+ ResizeBilinearOpModel m({2, 2, 2, 1}, 3, 3);
+ m.SetInput({
+ 3, 6, //
+ 9, 12, //
+ 4, 10, //
+ 10, 16 //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 3, 5, 6, //
+ 7, 9, 10, //
+ 9, 11, 12, //
+ 4, 8, 10, //
+ 8, 12, 14, //
+ 10, 14, 16, //
+ })));
+}
+
+TEST(ResizeBilinearOpTest, ThreeDimensionalResize) {
+ ResizeBilinearOpModel m({1, 2, 2, 2}, 3, 3);
+ m.SetInput({
+ 3, 4, 6, 10, //
+ 9, 10, 12, 16, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 3, 4, 5, 8, 6, 10, //
+ 7, 8, 9, 12, 10, 14, //
+ 9, 10, 11, 14, 12, 16, //
+ })));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/skip_gram.cc b/tensorflow/contrib/lite/kernels/skip_gram.cc
new file mode 100644
index 0000000000..c90a15b3a2
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/skip_gram.cc
@@ -0,0 +1,160 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Generate a list of skip grams from an input.
+//
+// Options:
+// ngram_size: num of words for each output item.
+// max_skip_size: max num of words to skip.
+// The op generates ngrams when it is 0.
+// include_all_ngrams: include all ngrams with size up to ngram_size.
+//
+// Input:
+// A string tensor to generate n-grams.
+// Dim = {1}
+//
+// Output:
+// A list of strings, each of which contains ngram_size words.
+// Dim = {num_ngram}
+
+#include <ctype.h>
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+
+namespace {
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TF_LITE_ENSURE_EQ(context, GetInput(context, node, 0)->type, kTfLiteString);
+ TF_LITE_ENSURE_EQ(context, GetOutput(context, node, 0)->type, kTfLiteString);
+ return kTfLiteOk;
+}
+
+bool ShouldIncludeCurrentNgram(const TfLiteSkipGramParams* params, int size) {
+ if (size <= 0) {
+ return false;
+ }
+ if (params->include_all_ngrams) {
+ return size <= params->ngram_size;
+ } else {
+ return size == params->ngram_size;
+ }
+}
+
+bool ShouldStepInRecursion(const TfLiteSkipGramParams* params,
+ const std::vector<int>& stack, int stack_idx,
+ int num_words) {
+ // If current stack size and next word enumeration are within valid range.
+ if (stack_idx < params->ngram_size && stack[stack_idx] + 1 < num_words) {
+ // If this stack is empty, step in for first word enumeration.
+ if (stack_idx == 0) {
+ return true;
+ }
+ // If next word enumeration are within the range of max_skip_size.
+ // NOTE: equivalent to
+ // next_word_idx = stack[stack_idx] + 1
+ // next_word_idx - stack[stack_idx-1] <= max_skip_size + 1
+ if (stack[stack_idx] - stack[stack_idx - 1] <= params->max_skip_size) {
+ return true;
+ }
+ }
+ return false;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteSkipGramParams*>(node->builtin_data);
+
+ // Split sentence to words.
+ std::vector<StringRef> words;
+ tflite::StringRef strref = tflite::GetString(GetInput(context, node, 0), 0);
+ int prev_idx = 0;
+ for (int i = 1; i < strref.len; i++) {
+ if (isspace(*(strref.str + i))) {
+ if (i > prev_idx && !isspace(*(strref.str + prev_idx))) {
+ words.push_back({strref.str + prev_idx, i - prev_idx});
+ }
+ prev_idx = i + 1;
+ }
+ }
+ if (strref.len > prev_idx) {
+ words.push_back({strref.str + prev_idx, strref.len - prev_idx});
+ }
+
+ // Generate n-grams recursively.
+ tflite::DynamicBuffer buf;
+ if (words.size() < params->ngram_size) {
+ buf.WriteToTensor(GetOutput(context, node, 0));
+ return kTfLiteOk;
+ }
+
+ // Stack stores the index of word used to generate ngram.
+ // The size of stack is the size of ngram.
+ std::vector<int> stack(params->ngram_size, 0);
+ // Stack index that indicates which depth the recursion is operating at.
+ int stack_idx = 1;
+ int num_words = words.size();
+
+ while (stack_idx >= 0) {
+ if (ShouldStepInRecursion(params, stack, stack_idx, num_words)) {
+ // When current depth can fill with a new word
+ // and the new word is within the max range to skip,
+ // fill this word to stack, recurse into next depth.
+ stack[stack_idx]++;
+ stack_idx++;
+ if (stack_idx < params->ngram_size) {
+ stack[stack_idx] = stack[stack_idx - 1];
+ }
+ } else {
+ if (ShouldIncludeCurrentNgram(params, stack_idx)) {
+ // Add n-gram to tensor buffer when the stack has filled with enough
+ // words to generate the ngram.
+ std::vector<StringRef> gram(stack_idx);
+ for (int i = 0; i < stack_idx; i++) {
+ gram[i] = words[stack[i]];
+ }
+ buf.AddJoinedString(gram, ' ');
+ }
+ // When current depth cannot fill with a valid new word,
+ // and not in last depth to generate ngram,
+ // step back to previous depth to iterate to next possible word.
+ stack_idx--;
+ }
+ }
+
+ buf.WriteToTensor(GetOutput(context, node, 0));
+ return kTfLiteOk;
+}
+} // namespace
+
+TfLiteRegistration* Register_SKIP_GRAM() {
+ static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/skip_gram_test.cc b/tensorflow/contrib/lite/kernels/skip_gram_test.cc
new file mode 100644
index 0000000000..e7f6bc904b
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/skip_gram_test.cc
@@ -0,0 +1,257 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+
+static char kSentence[] = "The quick\t brown fox\n jumps over\n the lazy dog!";
+
+class SkipGramOp : public SingleOpModel {
+ public:
+ SkipGramOp(int ngram_size, int max_skip_size, bool include_all_ngrams) {
+ input_ = AddInput(TensorType_STRING);
+ output_ = AddOutput(TensorType_STRING);
+
+ SetBuiltinOp(BuiltinOperator_SKIP_GRAM, BuiltinOptions_SkipGramOptions,
+ CreateSkipGramOptions(builder_, ngram_size, max_skip_size,
+ include_all_ngrams)
+ .Union());
+ BuildInterpreter({{1}});
+ }
+ void SetInput(const string& content) {
+ PopulateStringTensor(input_, {content});
+ }
+
+ std::vector<string> GetOutput() {
+ std::vector<string> ans;
+ TfLiteTensor* tensor = interpreter_->tensor(output_);
+
+ int num = GetStringCount(tensor);
+ for (int i = 0; i < num; i++) {
+ StringRef strref = GetString(tensor, i);
+ ans.push_back(string(strref.str, strref.len));
+ }
+ return ans;
+ }
+
+ private:
+ int input_;
+ int output_;
+};
+
+TEST(SkipGramTest, TestUnigram) {
+ SkipGramOp m(1, 0, false);
+
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), testing::UnorderedElementsAreArray(
+ {"The", "quick", "brown", "fox", "jumps",
+ "over", "the", "lazy", "dog!"}));
+}
+
+TEST(SkipGramTest, TestBigram) {
+ SkipGramOp m(2, 0, false);
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ testing::UnorderedElementsAreArray(
+ {"The quick", "quick brown", "brown fox", "fox jumps",
+ "jumps over", "over the", "the lazy", "lazy dog!"}));
+}
+
+TEST(SkipGramTest, TestAllBigram) {
+ SkipGramOp m(2, 0, true);
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ testing::UnorderedElementsAreArray(
+ {// Unigram
+ "The", "quick", "brown", "fox", "jumps", "over", "the",
+ "lazy", "dog!",
+ // Bigram
+ "The quick", "quick brown", "brown fox", "fox jumps",
+ "jumps over", "over the", "the lazy", "lazy dog!"}));
+}
+
+TEST(SkipGramTest, TestAllTrigram) {
+ SkipGramOp m(3, 0, true);
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ testing::UnorderedElementsAreArray(
+ {// Unigram
+ "The", "quick", "brown", "fox", "jumps", "over", "the",
+ "lazy", "dog!",
+ // Bigram
+ "The quick", "quick brown", "brown fox", "fox jumps",
+ "jumps over", "over the", "the lazy", "lazy dog!",
+ // Trigram
+ "The quick brown", "quick brown fox", "brown fox jumps",
+ "fox jumps over", "jumps over the", "over the lazy",
+ "the lazy dog!"}));
+}
+
+TEST(SkipGramTest, TestSkip1Bigram) {
+ SkipGramOp m(2, 1, false);
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetOutput(),
+ testing::UnorderedElementsAreArray(
+ {"The quick", "The brown", "quick brown", "quick fox", "brown fox",
+ "brown jumps", "fox jumps", "fox over", "jumps over", "jumps the",
+ "over the", "over lazy", "the lazy", "the dog!", "lazy dog!"}));
+}
+
+TEST(SkipGramTest, TestSkip2Bigram) {
+ SkipGramOp m(2, 2, false);
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ testing::UnorderedElementsAreArray(
+ {"The quick", "The brown", "The fox", "quick brown",
+ "quick fox", "quick jumps", "brown fox", "brown jumps",
+ "brown over", "fox jumps", "fox over", "fox the",
+ "jumps over", "jumps the", "jumps lazy", "over the",
+ "over lazy", "over dog!", "the lazy", "the dog!",
+ "lazy dog!"}));
+}
+
+TEST(SkipGramTest, TestSkip1Trigram) {
+ SkipGramOp m(3, 1, false);
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ testing::UnorderedElementsAreArray(
+ {"The quick brown", "The quick fox", "The brown fox",
+ "The brown jumps", "quick brown fox", "quick brown jumps",
+ "quick fox jumps", "quick fox over", "brown fox jumps",
+ "brown fox over", "brown jumps over", "brown jumps the",
+ "fox jumps over", "fox jumps the", "fox over the",
+ "fox over lazy", "jumps over the", "jumps over lazy",
+ "jumps the lazy", "jumps the dog!", "over the lazy",
+ "over the dog!", "over lazy dog!", "the lazy dog!"}));
+}
+
+TEST(SkipGramTest, TestSkip2Trigram) {
+ SkipGramOp m(3, 2, false);
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ testing::UnorderedElementsAreArray(
+ {"The quick brown", "The quick fox", "The quick jumps",
+ "The brown fox", "The brown jumps", "The brown over",
+ "The fox jumps", "The fox over", "The fox the",
+ "quick brown fox", "quick brown jumps", "quick brown over",
+ "quick fox jumps", "quick fox over", "quick fox the",
+ "quick jumps over", "quick jumps the", "quick jumps lazy",
+ "brown fox jumps", "brown fox over", "brown fox the",
+ "brown jumps over", "brown jumps the", "brown jumps lazy",
+ "brown over the", "brown over lazy", "brown over dog!",
+ "fox jumps over", "fox jumps the", "fox jumps lazy",
+ "fox over the", "fox over lazy", "fox over dog!",
+ "fox the lazy", "fox the dog!", "jumps over the",
+ "jumps over lazy", "jumps over dog!", "jumps the lazy",
+ "jumps the dog!", "jumps lazy dog!", "over the lazy",
+ "over the dog!", "over lazy dog!", "the lazy dog!"}));
+}
+
+TEST(SkipGramTest, TestAllSkip2Trigram) {
+ SkipGramOp m(3, 2, true);
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetOutput(),
+ testing::UnorderedElementsAreArray(
+ {// Unigram
+ "The", "quick", "brown", "fox", "jumps", "over", "the", "lazy",
+ "dog!",
+ // Bigram
+ "The quick", "The brown", "The fox", "quick brown", "quick fox",
+ "quick jumps", "brown fox", "brown jumps", "brown over", "fox jumps",
+ "fox over", "fox the", "jumps over", "jumps the", "jumps lazy",
+ "over the", "over lazy", "over dog!", "the lazy", "the dog!",
+ "lazy dog!",
+ // Trigram
+ "The quick brown", "The quick fox", "The quick jumps",
+ "The brown fox", "The brown jumps", "The brown over",
+ "The fox jumps", "The fox over", "The fox the", "quick brown fox",
+ "quick brown jumps", "quick brown over", "quick fox jumps",
+ "quick fox over", "quick fox the", "quick jumps over",
+ "quick jumps the", "quick jumps lazy", "brown fox jumps",
+ "brown fox over", "brown fox the", "brown jumps over",
+ "brown jumps the", "brown jumps lazy", "brown over the",
+ "brown over lazy", "brown over dog!", "fox jumps over",
+ "fox jumps the", "fox jumps lazy", "fox over the", "fox over lazy",
+ "fox over dog!", "fox the lazy", "fox the dog!", "jumps over the",
+ "jumps over lazy", "jumps over dog!", "jumps the lazy",
+ "jumps the dog!", "jumps lazy dog!", "over the lazy",
+ "over the dog!", "over lazy dog!", "the lazy dog!"}));
+}
+
+TEST(SkipGramTest, TestSingleWord) {
+ SkipGramOp m(1, 1, false);
+ m.SetInput("Hi");
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAre("Hi"));
+}
+
+TEST(SkipGramTest, TestWordsLessThanGram) {
+ SkipGramOp m(3, 1, false);
+ m.SetInput("Hi hi");
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), std::vector<string>());
+}
+
+TEST(SkipGramTest, TestEmptyInput) {
+ SkipGramOp m(1, 1, false);
+ m.SetInput("");
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAre());
+}
+
+TEST(SkipGramTest, TestWhitespaceInput) {
+ SkipGramOp m(1, 1, false);
+ m.SetInput(" ");
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAre());
+}
+
+TEST(SkipGramTest, TestInputWithExtraSpace) {
+ SkipGramOp m(1, 1, false);
+ m.SetInput(" Hello world ! ");
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAre("Hello", "world", "!"));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/softmax_test.cc b/tensorflow/contrib/lite/kernels/softmax_test.cc
new file mode 100644
index 0000000000..ec8ec03b0d
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/softmax_test.cc
@@ -0,0 +1,143 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite SOFTMAX op.
+
+#include <iomanip>
+#include <memory>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+class SoftmaxOpModel : public SingleOpModel {
+ public:
+ SoftmaxOpModel(int batches, int size, float beta)
+ : batches_(batches), input_size_(size), beta_(beta) {
+ input_ = AddInput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions,
+ CreateSoftmaxOptions(builder_, beta_).Union());
+ BuildInterpreter({{batches_, input_size_}});
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+
+ void SetInput(int offset, float* begin, float* end) {
+ PopulateTensor(input_, offset, begin, end);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ private:
+ int input_;
+ int output_;
+
+ int batches_;
+ int input_size_;
+ float beta_;
+};
+
+TEST(SoftmaxOpTest, SimpleTest) {
+ SoftmaxOpModel m(/*batches=*/2, /*size=*/5, /*beta=*/1.0);
+ m.SetInput({
+ 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0
+ -1.0, -2.0, -3.0, -4.0, -5.0, // b = 0
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647,
+ 0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231},
+ 1e-6)));
+}
+
+TEST(SoftmaxOpTest, CompareWithTFminiBetaEq1) {
+ const int batch_size = 2;
+ const int input_size = 5;
+ const float beta = 1.0;
+ static float input_buffer[] = {
+ 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0
+ -1.0, -2.0, -3.0, -4.0, -5.0, // b = 1
+ };
+
+ SoftmaxOpModel m(batch_size, input_size, beta);
+
+ m.SetInput(0, input_buffer, input_buffer + input_size * batch_size);
+
+ m.Invoke();
+
+ std::unique_ptr<float[]> output_buffer(new float[input_size * batch_size]);
+ static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size},
+ {1, 0, 0, input_size}};
+ tflite::reference_ops::Softmax(input_buffer, input_dims, beta,
+ output_buffer.get(), input_dims);
+
+ std::vector<float> expected;
+ expected.insert(expected.end(), output_buffer.get(),
+ output_buffer.get() + input_size * batch_size);
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected, 1e-6)));
+}
+
+TEST(SoftmaxOpTest, CompareWithTFminiBetaNotEq1) {
+ const int batch_size = 2;
+ const int input_size = 5;
+ const float beta = 0.5;
+ static float input_buffer[] = {
+ 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0
+ -1.0, -2.0, -3.0, -4.0, -5.0, // b = 1
+ };
+
+ SoftmaxOpModel m(batch_size, input_size, beta);
+
+ m.SetInput(0, input_buffer, input_buffer + input_size * batch_size);
+
+ m.Invoke();
+
+ std::unique_ptr<float[]> output_buffer(new float[input_size * batch_size]);
+ static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size},
+ {1, 0, 0, input_size}};
+ tflite::reference_ops::Softmax(input_buffer, input_dims, beta,
+ output_buffer.get(), input_dims);
+
+ std::vector<float> expected;
+ expected.insert(expected.end(), output_buffer.get(),
+ output_buffer.get() + input_size * batch_size);
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected, 1e-6)));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/space_to_depth.cc b/tensorflow/contrib/lite/kernels/space_to_depth.cc
new file mode 100644
index 0000000000..cb2e509c98
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/space_to_depth.cc
@@ -0,0 +1,146 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace space_to_depth {
+
+// This file has two implementation of SpaceToDepth. Note that SpaceToDepth
+// only works on 4D tensors.
+enum KernelType {
+ kReference,
+ kGenericOptimized,
+};
+
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteSpaceToDepthParams*>(node->builtin_data);
+
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
+
+ auto data_type = output->type;
+ TF_LITE_ENSURE(context,
+ data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8 ||
+ data_type == kTfLiteInt32 || data_type == kTfLiteInt64);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+ const int block_size = params->block_size;
+ const int input_height = input->dims->data[1];
+ const int input_width = input->dims->data[2];
+ int output_height = input_height / block_size;
+ int output_width = input_width / block_size;
+
+ TF_LITE_ENSURE_EQ(context, input_height, output_height * block_size);
+ TF_LITE_ENSURE_EQ(context, input_width, output_width * block_size);
+
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
+ output_size->data[0] = input->dims->data[0];
+ output_size->data[1] = output_height;
+ output_size->data[2] = output_width;
+ output_size->data[3] = input->dims->data[3] * block_size * block_size;
+
+ return context->ResizeTensor(context, output, output_size);
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteSpaceToDepthParams*>(node->builtin_data);
+
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+#define TF_LITE_SPACE_TO_DEPTH(type, scalar) \
+ type::SpaceToDepth<scalar>( \
+ GetTensorData<scalar>(input), GetTensorDims(input), params->block_size, \
+ GetTensorData<scalar>(output), GetTensorDims(output))
+ switch (input->type) { // Already know in/out types are same.
+ case kTfLiteFloat32:
+ if (kernel_type == kReference) {
+ TF_LITE_SPACE_TO_DEPTH(reference_ops, float);
+ } else {
+ TF_LITE_SPACE_TO_DEPTH(optimized_ops, float);
+ }
+ break;
+ case kTfLiteUInt8:
+ if (kernel_type == kReference) {
+ TF_LITE_SPACE_TO_DEPTH(reference_ops, uint8_t);
+ } else {
+ TF_LITE_SPACE_TO_DEPTH(optimized_ops, uint8_t);
+ }
+ break;
+ case kTfLiteInt32:
+ if (kernel_type == kReference) {
+ TF_LITE_SPACE_TO_DEPTH(reference_ops, int32_t);
+ } else {
+ TF_LITE_SPACE_TO_DEPTH(optimized_ops, int32_t);
+ }
+ break;
+ case kTfLiteInt64:
+ if (kernel_type == kReference) {
+ TF_LITE_SPACE_TO_DEPTH(reference_ops, int64_t);
+ } else {
+ TF_LITE_SPACE_TO_DEPTH(optimized_ops, int64_t);
+ }
+ break;
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+#undef TF_LITE_SPACE_TO_DEPTH
+
+ return kTfLiteOk;
+}
+
+} // namespace space_to_depth
+
+TfLiteRegistration* Register_SPACE_TO_DEPTH_REF() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, space_to_depth::Prepare,
+ space_to_depth::Eval<space_to_depth::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_SPACE_TO_DEPTH_GENERIC_OPT() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, space_to_depth::Prepare,
+ space_to_depth::Eval<space_to_depth::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_SPACE_TO_DEPTH() {
+ return Register_SPACE_TO_DEPTH_GENERIC_OPT();
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/space_to_depth_test.cc b/tensorflow/contrib/lite/kernels/space_to_depth_test.cc
new file mode 100644
index 0000000000..911f08a92c
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/space_to_depth_test.cc
@@ -0,0 +1,102 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+
+class SpaceToDepthOpModel : public SingleOpModel {
+ public:
+ SpaceToDepthOpModel(const TensorData& tensor_data, int block_size) {
+ input_ = AddInput(tensor_data);
+ output_ = AddOutput(tensor_data);
+ SetBuiltinOp(BuiltinOperator_SPACE_TO_DEPTH,
+ BuiltinOptions_SpaceToDepthOptions,
+ CreateSpaceToDepthOptions(builder_, block_size).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ template <typename T>
+ void SetInput(std::initializer_list<T> data) {
+ PopulateTensor<T>(input_, data);
+ }
+ template <typename T>
+ std::vector<T> GetOutput() {
+ return ExtractVector<T>(output_);
+ }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input_;
+ int output_;
+};
+
+TEST(SpaceToDepthOpModel, BadBlockSize) {
+ EXPECT_DEATH(SpaceToDepthOpModel({TensorType_FLOAT32, {1, 2, 2, 1}}, 3),
+ "Cannot allocate tensors");
+}
+
+TEST(SpaceToDepthOpModel, Float32) {
+ SpaceToDepthOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}, 2);
+ m.SetInput<float>({1.4, 2.3, 3.2, 4.1, 5.4, 6.3, 7.2, 8.1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray({1.4, 2.3, 3.2, 4.1, 5.4, 6.3, 7.2, 8.1}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 8));
+}
+
+TEST(SpaceToDepthOpModel, Uint8) {
+ SpaceToDepthOpModel m({TensorType_UINT8, {1, 2, 2, 1}}, 2);
+ m.SetInput<uint8_t>({1, 2, 3, 4});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray({1, 2, 3, 4}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(SpaceToDepthOpModel, Int32) {
+ SpaceToDepthOpModel m({TensorType_INT32, {1, 2, 2, 3}}, 2);
+ m.SetInput<int32_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<int32_t>(),
+ ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 12));
+}
+
+TEST(SpaceToDepthOpModel, Int64) {
+ SpaceToDepthOpModel m({TensorType_INT64, {1, 4, 4, 1}}, 2);
+ m.SetInput<int64_t>({1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<int64_t>(),
+ ElementsAreArray(
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 2, 2, 4));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc
new file mode 100644
index 0000000000..dd414d53bd
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/svdf.cc
@@ -0,0 +1,224 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdlib>
+#include <cstdio>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace svdf {
+
+constexpr int kInputTensor = 0;
+constexpr int kWeightsFeatureTensor = 1;
+constexpr int kWeightsTimeTensor = 2;
+constexpr int kBiasTensor = 3;
+constexpr int kStateTensor = 0;
+constexpr int KOutputTensor = 1;
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* scratch_tensor_index = new int;
+ context->AddTensors(context, 1, scratch_tensor_index);
+ return scratch_tensor_index;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<int*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
+ int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+
+ // Check we have all the inputs and outputs we need.
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+
+ TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
+ TfLiteTensor* weights_feature =
+ &context->tensors[node->inputs->data[kWeightsFeatureTensor]];
+ TfLiteTensor* weights_time =
+ &context->tensors[node->inputs->data[kWeightsTimeTensor]];
+
+ // Check all the parameters of tensor match within themselves and match the
+ // input configuration.
+ const int rank = params->rank;
+ const int batch_size = input->dims->data[0];
+ const int num_filters = weights_feature->dims->data[0];
+ TF_LITE_ASSERT_EQ(num_filters % rank, 0);
+ const int num_units = num_filters / rank;
+ const int memory_size = weights_time->dims->data[1];
+ TF_LITE_ASSERT_EQ(input->dims->data[1], weights_feature->dims->data[1]);
+ TF_LITE_ASSERT_EQ(weights_time->dims->data[0], num_filters);
+
+ TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
+ if (bias) {
+ TF_LITE_ASSERT_EQ(bias->dims->data[0], num_units);
+ }
+
+ TfLiteTensor* state = &context->tensors[node->outputs->data[kStateTensor]];
+ TfLiteTensor* output = &context->tensors[node->outputs->data[KOutputTensor]];
+
+ // Resize state.
+ // For each batch, the state is a 2-D tensor: memory_size * num_filters
+ // The left most column is used to save current cycle activation.
+ // The right most column is used to save temporary output which will be
+ // reduced to num_units outputs.
+ TfLiteIntArray* state_size_array = TfLiteIntArrayCreate(2);
+ state_size_array->data[0] = batch_size;
+ state_size_array->data[1] = memory_size * num_filters;
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, state, state_size_array));
+
+ // Mark state as a persistent tensor.
+ state->allocation_type = kTfLiteArenaRwPersistent;
+
+ // Resize output.
+ TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
+ output_size_array->data[0] = batch_size;
+ output_size_array->data[1] = num_units;
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, output, output_size_array));
+
+ // Resize scratch.
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(1);
+ node->temporaries->data[0] = *scratch_tensor_index;
+
+ TfLiteIntArray* scratch_size_array = TfLiteIntArrayCreate(2);
+ scratch_size_array->data[0] = batch_size;
+ scratch_size_array->data[1] = num_filters;
+
+ TfLiteTensor* scratch_tensor = &context->tensors[node->temporaries->data[0]];
+ scratch_tensor->type = input->type;
+ scratch_tensor->allocation_type = kTfLiteArenaRw;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_tensor,
+ scratch_size_array));
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
+
+ TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
+ TfLiteTensor* weights_feature =
+ &context->tensors[node->inputs->data[kWeightsFeatureTensor]];
+ TfLiteTensor* weights_time =
+ &context->tensors[node->inputs->data[kWeightsTimeTensor]];
+
+ TfLiteTensor* state = &context->tensors[node->outputs->data[kStateTensor]];
+ TfLiteTensor* output = &context->tensors[node->outputs->data[KOutputTensor]];
+ TfLiteTensor* scratch = &context->tensors[node->temporaries->data[0]];
+
+ TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
+
+ const int rank = params->rank;
+ const int batch_size = input->dims->data[0];
+ const int input_size = input->dims->data[1];
+ const int num_filters = weights_feature->dims->data[0];
+ const int num_units = num_filters / rank;
+ const int memory_size = weights_time->dims->data[1];
+
+ // Clear the activation (state left most column).
+ // TODO(ghodrat): Add a test which initialize state with invalid values in
+ // left most column and make sure it passes.
+ for (int b = 0; b < batch_size; b++) {
+ float* state_ptr_batch = state->data.f + b * memory_size * num_filters;
+ for (int c = 0; c < num_filters; c++) {
+ float* state_ptr = state_ptr_batch + c * memory_size;
+ state_ptr[memory_size - 1] = 0.0;
+ }
+ }
+
+ // Compute conv1d(inputs, weights_feature).
+ // The state left most column is used to save current cycle activation. This
+ // is achieved by starting at state->data.f[memory_size - 1] and having the
+ // stride equal to memory_size.
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ weights_feature->data.f, num_filters, input_size, input->data.f,
+ batch_size, &state->data.f[memory_size - 1], memory_size);
+
+ // Compute matmul(state, weights_time).
+ // The right most column is used to save temporary output (with the size of
+ // num_filters). This is achieved by starting at state->data.f and having the
+ // stride equal to memory_size.
+ for (int b = 0; b < batch_size; b++) {
+ float* state_ptr_batch = state->data.f + b * memory_size * num_filters;
+ float* scratch_ptr_batch = scratch->data.f + b * num_filters;
+ tensor_utils::BatchVectorBatchVectorDotProduct(
+ weights_time->data.f, state_ptr_batch, memory_size, num_filters,
+ scratch_ptr_batch, /*result_stride=*/1);
+ }
+
+ // Initialize output with bias if provided.
+ if (bias) {
+ tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size,
+ output->data.f);
+ } else {
+ tensor_utils::ZeroVector(output->data.f, batch_size * num_units);
+ }
+
+ // Reduction sum
+ // TODO(ghodrat): Consider not reusing state for the temporary output, this
+ // way ReductionSum operates on row-vector instead of column vector.
+ for (int b = 0; b < batch_size; b++) {
+ float* output_ptr_batch = output->data.f + b * num_units;
+ float* scratch_ptr_batch = scratch->data.f + b * num_filters;
+ tensor_utils::ReductionSumVector(scratch_ptr_batch, output_ptr_batch,
+ num_units, rank);
+ }
+
+ // Apply activation.
+ for (int b = 0; b < batch_size; b++) {
+ float* output_ptr_batch = output->data.f + b * num_units;
+ tensor_utils::ApplyActivationToVector(output_ptr_batch, num_units,
+ params->activation, output_ptr_batch);
+ }
+
+ // Right shift the state.
+ for (int b = 0; b < batch_size; b++) {
+ float* state_ptr_batch = state->data.f + b * memory_size * num_filters;
+ for (int f = 0; f < num_filters; f++) {
+ tensor_utils::VectorShiftLeft(state_ptr_batch, memory_size,
+ /*shift_value=*/0.0);
+ state_ptr_batch += memory_size;
+ }
+ }
+ return kTfLiteOk;
+}
+
+} // namespace svdf
+
+TfLiteRegistration* Register_SVDF() {
+ static TfLiteRegistration r = {svdf::Init, svdf::Free, svdf::Prepare,
+ svdf::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/svdf_test.cc b/tensorflow/contrib/lite/kernels/svdf_test.cc
new file mode 100644
index 0000000000..d956025e9d
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/svdf_test.cc
@@ -0,0 +1,312 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite SVDF op.
+
+#include <vector>
+#include <iomanip>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+static float svdf_input[] = {
+ 0.12609188, -0.46347019, -0.89598465,
+ 0.35867718, 0.36897406, 0.73463392,
+
+ 0.14278367, -1.64410412, -0.75222826,
+ -0.57290924, 0.12729003, 0.7567004,
+
+ 0.49837467, 0.19278903, 0.26584083,
+ 0.17660543, 0.52949083, -0.77931279,
+
+ -0.11186574, 0.13164264, -0.05349274,
+ -0.72674477, -0.5683046, 0.55900657,
+
+ -0.68892461, 0.37783599, 0.18263303,
+ -0.63690937, 0.44483393, -0.71817774,
+
+ -0.81299269, -0.86831826, 1.43940818,
+ -0.95760226, 1.82078898, 0.71135032,
+
+ -1.45006323, -0.82251364, -1.69082689,
+ -1.65087092, -1.89238167, 1.54172635,
+
+ 0.03966608, -0.24936394, -0.77526885,
+ 2.06740379, -1.51439476, 1.43768692,
+
+ 0.11771342, -0.23761693, -0.65898693,
+ 0.31088525, -1.55601168, -0.87661445,
+
+ -0.89477462, 1.67204106, -0.53235275,
+ -0.6230064, 0.29819036, 1.06939757,
+};
+
+static float svdf_golden_output_rank_1[] = {
+ 0.014899, -0.0517661, -0.143725, -0.00271883,
+ -0.03004015, 0.09565311, 0.1587342, 0.00784263,
+
+ 0.068281, -0.162217, -0.152268, 0.00323521,
+ 0.01582633, 0.03858774, -0.03001583, -0.02671271,
+
+ -0.0317821, -0.0333089, 0.0609602, 0.0333759,
+ -0.01432795, 0.05524484, 0.1101355, -0.02382665,
+
+ -0.00623099, -0.077701, -0.391193, -0.0136691,
+ -0.02333033, 0.02293761, 0.12338032, 0.04326871,
+
+ 0.201551, -0.164607, -0.179462, -0.0592739,
+ 0.01064911, -0.17503069, 0.07821996, -0.00224009,
+
+ 0.0886511, -0.0875401, -0.269283, 0.0281379,
+ -0.02282338, 0.09741908, 0.32973239, 0.12281385,
+
+ -0.201174, -0.586145, -0.628624, -0.0330412,
+ 0.24780814, -0.39304617, -0.22473189, 0.02589256,
+
+ -0.0839096, -0.299329, 0.108746, 0.109808,
+ 0.10084175, -0.06416984, 0.28936723, 0.0026358,
+
+ 0.419114, -0.237824, -0.422627, 0.175115,
+ -0.2314795, -0.18584411, -0.4228974, -0.12928449,
+
+ 0.36726, -0.522303, -0.456502, -0.175475,
+ 0.17012937, -0.34447709, 0.38505614, -0.28158101,
+};
+
+static float svdf_golden_output_rank_2[] = {
+ -0.09623547, -0.10193135, 0.11083051, -0.0347917,
+ 0.1141196, 0.12965347, -0.12652366, 0.01007236,
+
+ -0.16396809, -0.21247184, 0.11259045, -0.04156673,
+ 0.10132131, -0.06143532, -0.00924693, 0.10084561,
+
+ 0.01257364, 0.0506071, -0.19287863, -0.07162561,
+ -0.02033747, 0.22673416, 0.15487903, 0.02525555,
+
+ -0.1411963, -0.37054959, 0.01774767, 0.05867489,
+ 0.09607603, -0.0141301, -0.08995658, 0.12867066,
+
+ -0.27142537, -0.16955489, 0.18521598, -0.12528358,
+ 0.00331409, 0.11167502, 0.02218599, -0.07309391,
+
+ 0.09593632, -0.28361851, -0.0773851, 0.17199151,
+ -0.00075242, 0.33691186, -0.1536046, 0.16572715,
+
+ -0.27916506, -0.27626723, 0.42615682, 0.3225764,
+ -0.37472126, -0.55655634, -0.05013514, 0.289112,
+
+ -0.24418658, 0.07540751, -0.1940318, -0.08911639,
+ 0.00732617, 0.46737891, 0.26449674, 0.24888524,
+
+ -0.17225097, -0.54660404, -0.38795233, 0.08389944,
+ 0.07736043, -0.28260678, 0.15666828, 1.14949894,
+
+ -0.57454878, -0.64704704, 0.73235172, -0.34616736,
+ 0.21120001, -0.22927976, 0.02455296, -0.35906726,
+};
+
+// Derived class of SingleOpModel, which is used to test SVDF TFLite op.
+class SVDFOpModel : public SingleOpModel {
+ public:
+ SVDFOpModel(int batches, int units, int input_size, int memory_size, int rank)
+ : batches_(batches),
+ units_(units),
+ input_size_(input_size),
+ memory_size_(memory_size),
+ rank_(rank) {
+ input_ = AddInput(TensorType_FLOAT32);
+ weights_feature_ = AddInput(TensorType_FLOAT32);
+ weights_time_ = AddInput(TensorType_FLOAT32);
+ bias_ = AddNullInput();
+ state_ = AddOutput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(
+ BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions,
+ CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union());
+ BuildInterpreter({
+ {batches_, input_size_}, // Input tensor
+ {units_ * rank, input_size_}, // weights_feature tensor
+ {units_ * rank, memory_size_}, // weights_time tensor
+ {units_} // bias tensor
+ });
+ }
+
+ // Populates the weights_feature tensor.
+ void SetWeightsFeature(std::initializer_list<float> f) {
+ PopulateTensor(weights_feature_, f);
+ }
+
+ // Populates the weights_time tensor.
+ void SetWeightsTime(std::initializer_list<float> f) {
+ PopulateTensor(weights_time_, f);
+ }
+
+ // Populates the input tensor.
+ void SetInput(int offset, float* begin, float* end) {
+ PopulateTensor(input_, offset, begin, end);
+ }
+
+ // Resets the state of SVDF op by filling it with 0's.
+ void ResetState() {
+ const int zero_buffer_size = rank_ * units_ * batches_ * memory_size_;
+ std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
+ memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
+ PopulateTensor(state_, 0, zero_buffer.get(),
+ zero_buffer.get() + zero_buffer_size);
+ }
+
+ // Extracts the output tensor from the SVDF op.
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ int input_size() { return input_size_; }
+ int num_units() { return units_; }
+ int num_batches() { return batches_; }
+
+ private:
+ int input_;
+ int weights_feature_;
+ int weights_time_;
+ int bias_;
+ int state_;
+ int output_;
+
+ int batches_;
+ int units_;
+ int input_size_;
+ int memory_size_;
+ int rank_;
+};
+
+TEST(SVDFOpTest, BlackBoxTestRank1) {
+ SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
+ /*memory_size=*/10, /*rank=*/1);
+ svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
+ 0.22197971, 0.12416199, 0.27901134, 0.27557442,
+ 0.3905206, -0.36137494, -0.06634006, -0.10640851});
+
+ svdf.SetWeightsTime(
+ {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
+ 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
+
+ 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
+ -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
+
+ -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
+ 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
+
+ -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
+ -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657});
+
+ svdf.ResetState();
+ const int svdf_num_batches = svdf.num_batches();
+ const int svdf_input_size = svdf.input_size();
+ const int svdf_num_units = svdf.num_units();
+ const int input_sequence_size =
+ sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches);
+ // Going over each input batch, setting the input tensor, invoking the SVDF op
+ // and checking the output with the expected golden values.
+ for (int i = 0; i < input_sequence_size; i++) {
+ float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches;
+ float* batch_end = batch_start + svdf_input_size * svdf_num_batches;
+ svdf.SetInput(0, batch_start, batch_end);
+
+ svdf.Invoke();
+
+ float* golden_start =
+ svdf_golden_output_rank_1 + i * svdf_num_units * svdf_num_batches;
+ float* golden_end = golden_start + svdf_num_units * svdf_num_batches;
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+
+ EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ }
+}
+
+TEST(SVDFOpTest, BlackBoxTestRank2) {
+ SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
+ /*memory_size=*/10, /*rank=*/2);
+ svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347,
+ 0.12416199, 0.15785322, 0.27901134, 0.3905206,
+ 0.21931258, -0.36137494, -0.10640851, 0.31053296,
+ -0.36118156, -0.0976817, -0.36916667, 0.22197971,
+ 0.15294972, 0.38031587, 0.27557442, 0.39635518,
+ -0.21580373, -0.06634006, -0.02702999, 0.27072677});
+
+ svdf.SetWeightsTime(
+ {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
+ 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
+
+ 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
+ -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
+
+ -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
+ 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
+
+ -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
+ -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657,
+
+ -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486,
+ 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187,
+
+ -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589,
+ 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836,
+
+ -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277,
+ -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214,
+
+ 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326,
+ 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763});
+
+ svdf.ResetState();
+ const int svdf_num_batches = svdf.num_batches();
+ const int svdf_input_size = svdf.input_size();
+ const int svdf_num_units = svdf.num_units();
+ const int input_sequence_size =
+ sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches);
+ // Going over each input batch, setting the input tensor, invoking the SVDF op
+ // and checking the output with the expected golden values.
+ for (int i = 0; i < input_sequence_size; i++) {
+ float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches;
+ float* batch_end = batch_start + svdf_input_size * svdf_num_batches;
+ svdf.SetInput(0, batch_start, batch_end);
+
+ svdf.Invoke();
+
+ float* golden_start =
+ svdf_golden_output_rank_2 + i * svdf_num_units * svdf_num_batches;
+ float* golden_end = golden_start + svdf_num_units * svdf_num_batches;
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+
+ EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ }
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc
new file mode 100644
index 0000000000..f716ba8741
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/test_util.cc
@@ -0,0 +1,183 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+
+#include "tensorflow/contrib/lite/version.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tflite {
+
+using ::testing::FloatNear;
+using ::testing::Matcher;
+
+namespace {
+template <typename T>
+std::pair<float, int32_t> QuantizationParams(float f_min, float f_max) {
+ // These are required by many quantized operations.
+ CHECK_LE(f_min, 0);
+ CHECK_GE(f_max, 0);
+ T q_min = std::numeric_limits<T>::min();
+ T q_max = std::numeric_limits<T>::max();
+ float range = q_max - q_min;
+ float scale = (f_max - f_min) / range;
+ int32_t zero_point = std::min(
+ q_max,
+ std::max(q_min, static_cast<T>(std::round(q_min - f_min / scale))));
+ return {scale, zero_point};
+}
+} // namespace
+
+std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
+ float max_abs_error) {
+ std::vector<Matcher<float>> matchers;
+ matchers.reserve(values.size());
+ for (const float& v : values) {
+ matchers.emplace_back(FloatNear(v, max_abs_error));
+ }
+ return matchers;
+}
+
+int SingleOpModel::AddTensor(TensorData t) {
+ int id = tensors_.size();
+
+ // This is slightly different depending on whether we are adding a
+ // quantized or a regular tensor.
+ bool is_quantized = (t.min != 0 || t.max != 0 || t.scale != 0);
+
+ flatbuffers::Offset<QuantizationParameters> q_params = 0;
+
+ if (is_quantized) {
+ if (t.min != 0 || t.max != 0) {
+ if (t.type == TensorType_UINT8) {
+ std::tie(t.scale, t.zero_point) =
+ QuantizationParams<uint8_t>(t.min, t.max);
+ } else if (t.type == TensorType_INT32) {
+ std::tie(t.scale, t.zero_point) =
+ QuantizationParams<int32_t>(t.min, t.max);
+ } else {
+ LOG(FATAL) << "No support for the requested quantized type";
+ }
+ t.min = 0;
+ t.max = 0;
+ }
+
+ q_params = CreateQuantizationParameters(
+ builder_, /*min=*/0, /*max=*/0, builder_.CreateVector<float>({t.scale}),
+ builder_.CreateVector<int64_t>({t.zero_point}));
+ }
+
+ tensors_.push_back(CreateTensor(builder_, builder_.CreateVector<int>({}),
+ t.type, /*buffer=*/0,
+ /*name=*/0, q_params));
+
+ tensor_data_[id] = t;
+
+ return id;
+}
+
+int SingleOpModel::AddInput(const TensorData& t) {
+ int id = AddTensor(t);
+ inputs_.push_back(id);
+ return id;
+}
+
+int SingleOpModel::AddNullInput() {
+ int id = kOptionalTensor;
+ inputs_.push_back(id);
+ return id;
+}
+
+int SingleOpModel::AddOutput(const TensorData& t) {
+ int id = AddTensor(t);
+ outputs_.push_back(id);
+ return id;
+}
+
+void SingleOpModel::SetBuiltinOp(BuiltinOperator type,
+ BuiltinOptions builtin_options_type,
+ flatbuffers::Offset<void> builtin_options) {
+ opcodes_.push_back(CreateOperatorCode(builder_, type, 0));
+ operators_.push_back(CreateOperator(
+ builder_, /*opcode_index=*/0, builder_.CreateVector<int32_t>(inputs_),
+ builder_.CreateVector<int32_t>(outputs_), builtin_options_type,
+ builtin_options,
+ /*custom_options=*/0, CustomOptionsFormat_FLEXBUFFERS));
+}
+
+void SingleOpModel::SetCustomOp(
+ const string& name, const std::vector<uint8_t>& custom_option,
+ const std::function<TfLiteRegistration*()>& registeration) {
+ custom_registrations_[name] = registeration;
+ opcodes_.push_back(
+ CreateOperatorCodeDirect(builder_, BuiltinOperator_CUSTOM, name.data()));
+ operators_.push_back(CreateOperator(
+ builder_, /*opcode_index=*/0, builder_.CreateVector<int32_t>(inputs_),
+ builder_.CreateVector<int32_t>(outputs_), BuiltinOptions_NONE, 0,
+ builder_.CreateVector<uint8_t>(custom_option),
+ CustomOptionsFormat_FLEXBUFFERS));
+}
+
+void SingleOpModel::BuildInterpreter(
+ std::vector<std::vector<int>> input_shapes) {
+ auto opcodes = builder_.CreateVector(opcodes_);
+ auto operators = builder_.CreateVector(operators_);
+ auto tensors = builder_.CreateVector(tensors_);
+ auto inputs = builder_.CreateVector<int32_t>(inputs_);
+ auto outputs = builder_.CreateVector<int32_t>(outputs_);
+ // Create a single subgraph
+ std::vector<flatbuffers::Offset<SubGraph>> subgraphs;
+ auto subgraph = CreateSubGraph(builder_, tensors, inputs, outputs, operators);
+ subgraphs.push_back(subgraph);
+ auto subgraphs_flatbuffer = builder_.CreateVector(subgraphs);
+
+ std::vector<flatbuffers::Offset<Buffer>> buffers_vec;
+ auto buffers = builder_.CreateVector(buffers_vec);
+ auto description = builder_.CreateString("programmatic model");
+ builder_.Finish(CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes,
+ subgraphs_flatbuffer, description, buffers));
+
+ auto* model = GetModel(builder_.GetBufferPointer());
+
+ ops::builtin::BuiltinOpResolver builtins;
+ for (const auto& reg : custom_registrations_) {
+ builtins.AddCustom(reg.first.data(), reg.second());
+ }
+ InterpreterBuilder(model, builtins)(&interpreter_);
+
+ CHECK(interpreter_ != nullptr);
+
+ int i = 0;
+ for (const auto& shape : input_shapes) {
+ int input_idx = interpreter_->inputs()[i++];
+ if (input_idx == kOptionalTensor) continue;
+ CHECK(interpreter_->ResizeInputTensor(input_idx, shape) == kTfLiteOk);
+ }
+ CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
+ << "Cannot allocate tensors";
+}
+
+void SingleOpModel::Invoke() { CHECK(interpreter_->Invoke() == kTfLiteOk); }
+
+int32_t SingleOpModel::GetTensorSize(int index) const {
+ TfLiteTensor* t = interpreter_->tensor(index);
+ CHECK(t);
+ int total_size = 1;
+ for (int i = 0; i < t->dims->size; ++i) {
+ total_size *= t->dims->data[i];
+ }
+ return total_size;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h
new file mode 100644
index 0000000000..e68e494661
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/test_util.h
@@ -0,0 +1,202 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_
+
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/string_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tflite {
+
+inline void LogToStderr() {
+#ifdef PLATFORM_GOOGLE
+ FLAGS_logtostderr = true;
+#endif
+}
+
+// A gmock matcher that check that elements of a float vector match to a given
+// tolerance.
+std::vector<::testing::Matcher<float>> ArrayFloatNear(
+ const std::vector<float>& values, float max_abs_error = 1e-5);
+
+template <typename T>
+inline std::vector<T> Quantize(const std::vector<float>& data, float scale,
+ int32_t zero_point) {
+ std::vector<T> q;
+ for (float f : data) {
+ q.push_back(std::max(
+ std::numeric_limits<T>::min(),
+ std::min(std::numeric_limits<T>::max(),
+ static_cast<T>(std::round(zero_point + (f / scale))))));
+ }
+ return q;
+}
+
+template <typename T>
+inline std::vector<float> Dequantize(const std::vector<T>& data, float scale,
+ int32_t zero_point) {
+ std::vector<float> f;
+ for (T q : data) {
+ f.push_back(scale * (q - zero_point));
+ }
+ return f;
+}
+
+// A test model that contains a single operator. All operator inputs and
+// output are external to the model, so the tests can directly access them.
+// Typical usage:
+// SingleOpModel m;
+// int a = m.AddInput({TensorType_FLOAT32, a_shape});
+// int b = m.AddInput({TensorType_FLOAT32, b_shape});
+// int c = m.AddOutput({TensorType_FLOAT32, {}});
+// m.SetBuiltinOp(...);
+// m.BuildInterpreter({GetShape(a), GetShape(b)});
+// m.PopulateTensor(a, {...});
+// m.PopulateTensor(b, {...});
+// m.Invoke();
+// EXPECT_THAT(m.ExtractVector<float>(c), ArrayFloatNear({...}));
+//
+
+// A helper struct to construct test tensors. This is particularly useful for
+// quantized tensor which must have their scale and zero_point defined before
+// the actual data is known. This mimics what happens in practice: quantization
+// parameters are calculate during training.
+struct TensorData {
+ TensorType type;
+ std::vector<int> shape;
+ float min;
+ float max;
+ float scale;
+ int32_t zero_point;
+};
+
+class SingleOpModel {
+ public:
+ SingleOpModel() {}
+ ~SingleOpModel() {}
+
+ // Copying or assignment is disallowed to simplify ownership semantics.
+ SingleOpModel(const SingleOpModel&) = delete;
+ SingleOpModel& operator=(const SingleOpModel&) = delete;
+
+ // Add a TensorType input tensor and return its index.
+ int AddInput(TensorType type) { return AddInput(TensorData{type}); }
+ int AddInput(const TensorData& t);
+
+ // Add a null input tensor (optional input) and return kOptionalTensor.
+ int AddNullInput();
+
+ // Add a TensorType output tensor and return its index.
+ int AddOutput(TensorType type) { return AddOutput(TensorData{type}); }
+ int AddOutput(const TensorData& t);
+
+ template <typename T>
+ void QuantizeAndPopulate(int index, std::initializer_list<float> data) {
+ TfLiteTensor* t = interpreter_->tensor(index);
+ auto q = Quantize<T>(data, t->params.scale, t->params.zero_point);
+ PopulateTensor(index, 0, q.data(), q.data() + q.size());
+ }
+
+ const std::vector<int>& GetShape(int id) { return tensor_data_.at(id).shape; }
+
+ float GetScale(int id) { return tensor_data_.at(id).scale; }
+ int32_t GetZeroPoint(int id) { return tensor_data_.at(id).zero_point; }
+
+ // Define the operator in this model.
+ void SetBuiltinOp(BuiltinOperator type, BuiltinOptions builtin_options_type,
+ flatbuffers::Offset<void> builtin_options);
+ void SetCustomOp(const string& name,
+ const std::vector<uint8_t>& custom_option,
+ const std::function<TfLiteRegistration*()>& registeration);
+
+ // Build the interpreter for this model. Also, resize and allocate all
+ // tensors given the shapes of the inputs.
+ void BuildInterpreter(std::vector<std::vector<int>> input_shapes);
+
+ void Invoke();
+
+ void PopulateStringTensor(int index, const std::vector<string>& content) {
+ auto tensor = interpreter_->tensor(index);
+ DynamicBuffer buf;
+ for (const string& s : content) {
+ buf.AddString(s.data(), s.length());
+ }
+ buf.WriteToTensor(tensor);
+ }
+
+ // Populate the tensor given its index.
+ template <typename T>
+ void PopulateTensor(int index, std::initializer_list<T> data) {
+ T* v = interpreter_->typed_tensor<T>(index);
+ CHECK(v) << "No tensor with index '" << index << "'.";
+ for (T f : data) {
+ *v = f;
+ ++v;
+ }
+ }
+
+ // Partially populate the tensor, starting at the given offset.
+ template <typename T>
+ void PopulateTensor(int index, int offset, T* begin, T* end) {
+ T* v = interpreter_->typed_tensor<T>(index);
+ memcpy(v + offset, begin, (end - begin) * sizeof(T));
+ }
+
+ // Return a vector with the flattened contents of a tensor.
+ template <typename T>
+ std::vector<T> ExtractVector(int index) {
+ T* v = interpreter_->typed_tensor<T>(index);
+ CHECK(v);
+ return std::vector<T>(v, v + GetTensorSize(index));
+ }
+
+ std::vector<int> GetTensorShape(int index) {
+ std::vector<int> result;
+ TfLiteTensor* t = interpreter_->tensor(index);
+ for (int i = 0; i < t->dims->size; ++i) {
+ result.push_back(t->dims->data[i]);
+ }
+ return result;
+ }
+
+ protected:
+ int32_t GetTensorSize(int index) const;
+
+ flatbuffers::FlatBufferBuilder builder_;
+ std::unique_ptr<tflite::Interpreter> interpreter_;
+
+ private:
+ int AddTensor(TensorData t);
+
+ std::map<int, TensorData> tensor_data_;
+ std::vector<int32_t> inputs_;
+ std::vector<int32_t> outputs_;
+ std::vector<flatbuffers::Offset<Tensor>> tensors_;
+ std::vector<flatbuffers::Offset<OperatorCode>> opcodes_;
+ std::vector<flatbuffers::Offset<Operator>> operators_;
+ std::map<string, std::function<TfLiteRegistration*()>> custom_registrations_;
+};
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
new file mode 100644
index 0000000000..f8208f6f98
--- /dev/null
+++ b/tensorflow/contrib/lite/model.cc
@@ -0,0 +1,673 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <fcntl.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "tensorflow/contrib/lite/allocation.h"
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/nnapi_delegate.h"
+#include "tensorflow/contrib/lite/version.h"
+
+namespace tflite {
+
+const char* kEmptyTensorName = "";
+
+std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
+ const char* filename, ErrorReporter* error_reporter) {
+ std::unique_ptr<FlatBufferModel> model;
+ model.reset(new FlatBufferModel(filename, /*mmap_file=*/true, error_reporter,
+ /*use_nnapi=*/true));
+ if (!model->initialized()) model.reset();
+ return model;
+}
+
+std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
+ const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) {
+ std::unique_ptr<FlatBufferModel> model;
+ model.reset(new FlatBufferModel(buffer, buffer_size, error_reporter));
+ if (!model->initialized()) model.reset();
+ return model;
+}
+
+FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file,
+ ErrorReporter* error_reporter, bool use_nnapi)
+ : error_reporter_(error_reporter ? error_reporter
+ : DefaultErrorReporter()) {
+ if (mmap_file) {
+ if (use_nnapi && NNAPIExists())
+ allocation_ = new NNAPIAllocation(filename, error_reporter);
+ else
+ allocation_ = new MMAPAllocation(filename, error_reporter);
+ } else {
+ allocation_ = new FileCopyAllocation(filename, error_reporter);
+ }
+ if (!allocation_->valid()) return;
+ if (!CheckModelIdentifier()) return;
+
+ model_ = ::tflite::GetModel(allocation_->base());
+}
+
+bool FlatBufferModel::CheckModelIdentifier() const {
+ if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
+ const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
+ error_reporter_->Report(
+ "Model provided has model identifier '%c%c%c%c', should be '%s'\n",
+ ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier());
+ return false;
+ }
+ return true;
+}
+
+FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes,
+ ErrorReporter* error_reporter)
+ : error_reporter_(error_reporter ? error_reporter
+ : DefaultErrorReporter()) {
+ allocation_ = new MemoryAllocation(ptr, num_bytes, error_reporter);
+ if (!allocation_->valid()) return;
+ model_ = ::tflite::GetModel(allocation_->base());
+}
+
+FlatBufferModel::~FlatBufferModel() { delete allocation_; }
+
+InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model,
+ const OpResolver& op_resolver)
+ : model_(model.GetModel()),
+ op_resolver_(op_resolver),
+ error_reporter_(model.error_reporter()),
+ allocation_(model.allocation()) {}
+
+InterpreterBuilder::InterpreterBuilder(const ::tflite::Model* model,
+ const OpResolver& op_resolver,
+ ErrorReporter* error_reporter)
+ : model_(model),
+ op_resolver_(op_resolver),
+ error_reporter_(error_reporter ? error_reporter
+ : DefaultErrorReporter()) {}
+
+TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
+ TfLiteStatus status = kTfLiteOk;
+ auto opcodes = model_->operator_codes();
+ for (const OperatorCode* opcode : *opcodes) {
+ TfLiteRegistration* registration = nullptr;
+
+ if (opcode->builtin_code() != BuiltinOperator_CUSTOM) {
+ auto x = opcode->builtin_code();
+ flatbuffer_op_index_to_registration_types_.push_back(x);
+ registration = op_resolver_.FindOp(x);
+ if (registration == nullptr) {
+ error_reporter_->Report("Didn't find op for builtin opcode '%s'\n",
+ EnumNameBuiltinOperator(x));
+ status = kTfLiteError;
+ }
+ } else if (!opcode->custom_code()) {
+ error_reporter_->Report(
+ "Operator with builtin_code==0 has no custom_code.\n");
+ status = kTfLiteError;
+ } else {
+ const char* name = opcode->custom_code()->c_str();
+ registration = op_resolver_.FindOp(name);
+ flatbuffer_op_index_to_registration_types_.push_back(
+ BuiltinOperator_CUSTOM);
+ if (registration == nullptr) {
+ error_reporter_->Report("Didn't find custom op for name '%s'\n", name);
+ status = kTfLiteError;
+ }
+ }
+ flatbuffer_op_index_to_registration_.push_back(registration);
+ }
+ return status;
+}
+
+namespace {
+template <class T>
+std::vector<int> FlatBufferIntArrayToVector(T* flat_array) {
+ std::vector<int> ret(flat_array->Length());
+ for (int i = 0; i < flat_array->Length(); i++) {
+ ret[i] = flat_array->Get(i);
+ }
+ return ret;
+}
+
+// Allocate a structure using C malloc, but make sure the structure is a
+// POD structure that doesn't require constructors to run. The reason we do
+// this, is that Interpreter's C extension part will take ownership and wants
+// to use malloc() and free().
+template <class T>
+T* MallocPOD() {
+ static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
+ return static_cast<T*>(malloc(sizeof(T)));
+}
+
+// Parse the appropriate data out of the op.
+//
+// This handles builtin data explicitly as there are flatbuffer schemas.
+//
+// Returns memory that must be feed.
+void* ParseOpData(const Operator* op, BuiltinOperator op_type,
+ ErrorReporter* error_reporter) {
+ auto parse_padding = [](Padding padding) {
+ switch (padding) {
+ case Padding_SAME:
+ return kTfLitePaddingSame;
+ case Padding_VALID:
+ return kTfLitePaddingValid;
+ }
+ return kTfLitePaddingUnknown;
+ };
+ auto parse_activation = [](ActivationFunctionType activation) {
+ switch (activation) {
+ case ActivationFunctionType_NONE:
+ return kTfLiteActNone;
+ case ActivationFunctionType_RELU:
+ return kTfLiteActRelu;
+ case ActivationFunctionType_RELU1:
+ return kTfLiteActRelu1;
+ case ActivationFunctionType_RELU6:
+ return kTfLiteActRelu6;
+ case ActivationFunctionType_TANH:
+ return kTfLiteActTanh;
+ case ActivationFunctionType_SIGN_BIT:
+ return kTfLiteActSignBit;
+ }
+ return kTfLiteActNone;
+ };
+ auto parseLSHProjectionType = [](LSHProjectionType type) {
+ switch (type) {
+ case LSHProjectionType_SPARSE:
+ return kTfLiteLshProjectionSparse;
+ case LSHProjectionType_DENSE:
+ return kTfLiteLshProjectionDense;
+ default:
+ return kTfLiteLshProjectionUnknown;
+ }
+ };
+ auto parseCombinerType = [](CombinerType type) {
+ switch (type) {
+ case CombinerType_MEAN:
+ return kTfLiteCombinerTypeMean;
+ case CombinerType_SQRTN:
+ return kTfLiteCombinerTypeSqrtn;
+ case CombinerType_SUM:
+ default:
+ return kTfLiteCombinerTypeSum;
+ }
+ };
+
+ void* builtin_data = nullptr;
+ switch (op_type) {
+ case BuiltinOperator_CALL:
+ // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
+ // ok for now, since there is no call implementation either.
+ break;
+ case BuiltinOperator_CUSTOM:
+ break;
+ case BuiltinOperator_CONV_2D: {
+ TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
+ if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
+ params->padding = parse_padding(conv_params->padding());
+ params->stride_width = conv_params->stride_w();
+ params->stride_height = conv_params->stride_h();
+ params->activation =
+ parse_activation(conv_params->fused_activation_function());
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_TANH:
+ case BuiltinOperator_LOGISTIC:
+ case BuiltinOperator_RELU:
+ case BuiltinOperator_RELU1:
+ case BuiltinOperator_RELU6:
+ case BuiltinOperator_CONCAT_EMBEDDINGS:
+ break;
+ case BuiltinOperator_LSH_PROJECTION: {
+ TfLiteLSHProjectionParams* params =
+ MallocPOD<TfLiteLSHProjectionParams>();
+ if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) {
+ params->type = parseLSHProjectionType(lshParams->type());
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_AVERAGE_POOL_2D:
+ case BuiltinOperator_MAX_POOL_2D:
+ case BuiltinOperator_L2_POOL_2D: {
+ TfLitePoolParams* params = MallocPOD<TfLitePoolParams>();
+ if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
+ params->padding = parse_padding(pool_params->padding());
+ params->stride_width = pool_params->stride_w();
+ params->stride_height = pool_params->stride_h();
+ params->filter_width = pool_params->filter_width();
+ params->filter_height = pool_params->filter_height();
+ params->activation =
+ parse_activation(pool_params->fused_activation_function());
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_DEPTHWISE_CONV_2D: {
+ TfLiteDepthwiseConvParams* params =
+ MallocPOD<TfLiteDepthwiseConvParams>();
+ if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) {
+ params->padding = parse_padding(conv_params->padding());
+ params->stride_width = conv_params->stride_w();
+ params->stride_height = conv_params->stride_h();
+ params->depth_multiplier = conv_params->depth_multiplier();
+ params->activation =
+ parse_activation(conv_params->fused_activation_function());
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SVDF: {
+ TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>();
+ if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) {
+ params->rank = svdf_params->rank();
+ params->activation =
+ parse_activation(svdf_params->fused_activation_function());
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_RNN: {
+ TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>();
+ if (auto* rnn_params = op->builtin_options_as_RNNOptions()) {
+ params->activation =
+ parse_activation(rnn_params->fused_activation_function());
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_EMBEDDING_LOOKUP:
+ // no-op.
+ break;
+ case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: {
+ TfLiteEmbeddingLookupSparseParams* params =
+ MallocPOD<TfLiteEmbeddingLookupSparseParams>();
+ if (auto* embedding_params =
+ op->builtin_options_as_EmbeddingLookupSparseOptions()) {
+ params->combiner = parseCombinerType(embedding_params->combiner());
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_FULLY_CONNECTED: {
+ TfLiteFullyConnectedParams* params =
+ MallocPOD<TfLiteFullyConnectedParams>();
+ if (auto* fully_connected_params =
+ op->builtin_options_as_FullyConnectedOptions()) {
+ params->activation = parse_activation(
+ fully_connected_params->fused_activation_function());
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_HASHTABLE_LOOKUP:
+ // no-op.
+ break;
+ case BuiltinOperator_SOFTMAX: {
+ TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>();
+ if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) {
+ params->beta = softmax_params->beta();
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_CONCATENATION: {
+ TfLiteConcatenationParams* params =
+ MallocPOD<TfLiteConcatenationParams>();
+ if (auto* concatenation_params =
+ op->builtin_options_as_ConcatenationOptions()) {
+ params->activation =
+ parse_activation(concatenation_params->fused_activation_function());
+ params->axis = concatenation_params->axis();
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_MUL: {
+ auto* params = MallocPOD<TfLiteMulParams>();
+ if (auto* schema_params = op->builtin_options_as_MulOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_ADD: {
+ auto* params = MallocPOD<TfLiteAddParams>();
+ if (auto* schema_params = op->builtin_options_as_AddOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_L2_NORMALIZATION: {
+ auto* params = MallocPOD<TfLiteL2NormParams>();
+ if (auto* schema_params = op->builtin_options_as_L2NormOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: {
+ auto* params = MallocPOD<TfLiteLocalResponseNormParams>();
+ if (auto* schema_params =
+ op->builtin_options_as_LocalResponseNormalizationOptions()) {
+ params->radius = schema_params->radius();
+ params->bias = schema_params->bias();
+ params->alpha = schema_params->alpha();
+ params->beta = schema_params->beta();
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_LSTM: {
+ TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>();
+ if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
+ params->activation =
+ parse_activation(lstm_params->fused_activation_function());
+ params->cell_clip = lstm_params->cell_clip();
+ params->proj_clip = lstm_params->proj_clip();
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_RESIZE_BILINEAR: {
+ auto* params = MallocPOD<TfLiteResizeBilinearParams>();
+ if (auto* schema_params =
+ op->builtin_options_as_ResizeBilinearOptions()) {
+ params->new_height = schema_params->new_height();
+ params->new_width = schema_params->new_width();
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_RESHAPE: {
+ auto* params = MallocPOD<TfLiteReshapeParams>();
+ if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
+ auto* new_shape = schema_params->new_shape();
+ if (!new_shape) {
+ error_reporter->Report("No new_shape provided for Reshape\n");
+ } else {
+ params->num_dimensions = new_shape->Length();
+ if (params->num_dimensions > sizeof(params->shape) / sizeof(int)) {
+ error_reporter->Report(
+ "Found too many dimensions in Reshape's new_shape\n");
+ } else {
+ for (int i = 0; i < params->num_dimensions; ++i) {
+ params->shape[i] = new_shape->Get(i);
+ }
+ }
+ }
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SKIP_GRAM: {
+ TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>();
+ if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) {
+ params->ngram_size = skip_gram_params->ngram_size();
+ params->max_skip_size = skip_gram_params->max_skip_size();
+ params->include_all_ngrams = skip_gram_params->include_all_ngrams();
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SPACE_TO_DEPTH: {
+ auto* params = MallocPOD<TfLiteSpaceToDepthParams>();
+ if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) {
+ params->block_size = schema_params->block_size();
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ }
+ return builtin_data;
+}
+
+} // namespace
+
+TfLiteStatus InterpreterBuilder::ParseNodes(
+ const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
+ Interpreter* interpreter) {
+ TfLiteStatus status = kTfLiteOk;
+ for (int i = 0; i < operators->Length(); ++i) {
+ const auto* op = operators->Get(i);
+ int index = op->opcode_index();
+ if (index < 0 || index >= flatbuffer_op_index_to_registration_.size()) {
+ error_reporter_->Report("Missing registration for opcode_index %d\n",
+ index);
+ status = kTfLiteError;
+ continue;
+ }
+ const TfLiteRegistration* reg =
+ flatbuffer_op_index_to_registration_[op->opcode_index()];
+ if (reg == nullptr) {
+ error_reporter_->Report("Skipping op for opcode_index %d\n", index);
+ status = kTfLiteError;
+ continue;
+ }
+
+ auto op_type =
+ flatbuffer_op_index_to_registration_types_[op->opcode_index()];
+ if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) {
+ error_reporter_->Report(
+ "Found builtin operator %s with custom options.\n",
+ EnumNameBuiltinOperator(op_type));
+ }
+ if (op->custom_options()) {
+ interpreter->AddNodeWithParameters(
+ FlatBufferIntArrayToVector(op->inputs()),
+ FlatBufferIntArrayToVector(op->outputs()),
+ reinterpret_cast<const char*>(op->custom_options()->data()),
+ op->custom_options()->size(), nullptr, reg);
+ } else {
+ interpreter->AddNodeWithParameters(
+ FlatBufferIntArrayToVector(op->inputs()),
+ FlatBufferIntArrayToVector(op->outputs()), nullptr, 0,
+ ParseOpData(op, op_type, error_reporter_), reg);
+ }
+ }
+
+ return status;
+}
+
+TfLiteStatus InterpreterBuilder::ParseTensors(
+ const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
+ const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
+ Interpreter* interpreter) {
+ TfLiteStatus status = kTfLiteOk;
+
+ // A little helper to get the names of inputs and outputs. Note that they
+ // must outlive the interpreter.
+ auto get_name = [](const tflite::Tensor* t) -> const char* {
+ auto name = t->name();
+ if (name) return name->c_str();
+ return kEmptyTensorName;
+ };
+
+ for (int i = 0; i < tensors->Length(); ++i) {
+ const auto* tensor = tensors->Get(i);
+ std::vector<int> dims = FlatBufferIntArrayToVector(tensor->shape());
+
+ TfLiteQuantizationParams quantization;
+ quantization.scale = 0;
+ quantization.zero_point = 0;
+ auto* q_params = tensor->quantization();
+ if (q_params) {
+ // Note that the schema could hold per-channel quantization parameters
+ // but we really only support one value for the whole tensor.
+ // TODO(aselle): This breaks as well if these are nullptr's.
+ // TODO(aselle): This assumes non per-channel quantization.
+ if (q_params->scale()) quantization.scale = q_params->scale()->Get(0);
+ if (q_params->zero_point())
+ quantization.zero_point = q_params->zero_point()->Get(0);
+ }
+
+ TfLiteType type;
+ switch (tensor->type()) {
+ case TensorType_FLOAT32:
+ type = kTfLiteFloat32;
+ break;
+ case TensorType_INT32:
+ type = kTfLiteInt32;
+ break;
+ case TensorType_UINT8:
+ type = kTfLiteUInt8;
+ break;
+ case TensorType_INT64:
+ type = kTfLiteInt64;
+ break;
+ case TensorType_STRING:
+ type = kTfLiteString;
+ break;
+ default:
+ // tensorType = ArrayType::NONE;
+ error_reporter_->Report("Unimplemented data type %s (%d) in tensor\n",
+ EnumNameTensorType(tensor->type()),
+ tensor->type());
+ status = kTfLiteError;
+ continue;
+ }
+ auto get_readonly_data = [&](const char** buffer_data,
+ size_t* buffer_size) {
+ // TODO(aselle): Check what happens if we have an unspecified size
+ // constant.
+ *buffer_data = nullptr;
+ if (tensor->buffer() == 0) return kTfLiteOk;
+ if (tensor->buffer() >= buffers->size()) {
+ error_reporter_->Report(
+ "Tensor %d specifies out of range buffer %d (only %d buffers).\n",
+ i, tensor->buffer(), buffers->size());
+ return kTfLiteError;
+ }
+ if (auto* buffer = (*buffers)[tensor->buffer()]) {
+ if (auto* array = buffer->data()) {
+ if (size_t size = array->size()) {
+ *buffer_size = size;
+ *buffer_data = reinterpret_cast<const char*>(array->data());
+ return kTfLiteOk;
+ }
+ }
+ }
+ return kTfLiteOk;
+ };
+ size_t buffer_size = 0;
+ const char* buffer_ptr;
+ TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size));
+
+ if (buffer_ptr) {
+ if (interpreter->SetTensorParametersReadOnly(
+ i, type, get_name(tensor), dims, quantization, buffer_ptr,
+ buffer_size, allocation_) != kTfLiteOk) {
+ error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
+ i);
+ status = kTfLiteError;
+ }
+ } else {
+ if (interpreter->SetTensorParametersReadWrite(
+ i, type, get_name(tensor), dims, quantization) != kTfLiteOk) {
+ error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
+ i);
+ status = kTfLiteError;
+ }
+ }
+ }
+
+ return status;
+}
+
+TfLiteStatus InterpreterBuilder::operator()(
+ std::unique_ptr<Interpreter>* interpreter) {
+ if (!interpreter) {
+ error_reporter_->Report(
+ "Null output pointer passed to InterpreterBuilder.");
+ return kTfLiteError;
+ }
+
+ // Safe exit by deleting partially created interpreter, to reduce verbosity
+ // on error conditions. Use by return cleanup_on_error();
+ auto cleanup_and_error = [&interpreter]() {
+ interpreter->reset();
+ return kTfLiteError;
+ };
+
+ if (!model_) {
+ error_reporter_->Report("Null pointer passed in as model.");
+ return cleanup_and_error();
+ }
+
+ if (model_->version() != TFLITE_SCHEMA_VERSION) {
+ error_reporter_->Report(
+ "Model provided is schema version %d not equal "
+ "to supported version %d.\n",
+ model_->version(), TFLITE_SCHEMA_VERSION);
+ return cleanup_and_error();
+ }
+
+ if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) {
+ error_reporter_->Report("Registration failed.\n");
+ return cleanup_and_error();
+ }
+
+ // Flatbuffer model schemas define a list of opcodes independent of the graph.
+ // We first map those to registrations. This reduces string lookups for custom
+ // ops since we only do it once per custom op rather than once per custom op
+ // invocation in the model graph.
+ // Construct interpreter with correct number of tensors and operators.
+ auto* subgraphs = model_->subgraphs();
+ auto* buffers = model_->buffers();
+ if (subgraphs->size() != 1) {
+ error_reporter_->Report("Only 1 subgraph is currently supported.\n");
+ return cleanup_and_error();
+ }
+ const tflite::SubGraph* subgraph = (*subgraphs)[0];
+ auto operators = subgraph->operators();
+ auto tensors = subgraph->tensors();
+ if (!operators || !tensors || !buffers) {
+ error_reporter_->Report(
+ "Did not get operators, tensors, or buffers in input flat buffer.\n");
+ return cleanup_and_error();
+ }
+ interpreter->reset(new Interpreter(error_reporter_));
+ if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) {
+ return cleanup_and_error();
+ }
+
+ // Parse inputs/outputs
+ (**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs()));
+ (**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs()));
+
+ // Finally setup nodes and tensors
+ if (ParseNodes(operators, interpreter->get()) != kTfLiteOk)
+ return cleanup_and_error();
+ if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk)
+ return cleanup_and_error();
+
+ return kTfLiteOk;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h
new file mode 100644
index 0000000000..15659d33f3
--- /dev/null
+++ b/tensorflow/contrib/lite/model.h
@@ -0,0 +1,165 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Deserialization infrastructure for tflite. Provides functionality
+// to go from a serialized tflite model in flatbuffer format to an
+// interpreter.
+//
+// using namespace tflite;
+// StderrReporter error_reporter;
+// auto model = FlatBufferModel::BuildFromFile("interesting_model.tflite",
+// &error_reporter);
+// MyOpResolver resolver; // You need to subclass OpResolver to provide
+// // implementations.
+// InterpreterBuilder builder(*model, resolver);
+// std::unique_ptr<Interpreter> interpreter;
+// if(builder(&interpreter) == kTfLiteOk) {
+// .. run model inference with interpreter
+// }
+//
+// OpResolver must be defined to provide your kernel implementations to the
+// interpreter. This is environment specific and may consist of just the builtin
+// ops, or some custom operators you defined to extend tflite.
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_
+
+#include <memory>
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+
+// An RAII object that represents a read-only tflite model, copied from disk,
+// or mmapped. This uses flatbuffers as the serialization format.
+class FlatBufferModel {
+ public:
+ // Build a model based on a file. Return a nullptr in case of failure.
+ static std::unique_ptr<FlatBufferModel> BuildFromFile(
+ const char* filename,
+ ErrorReporter* error_reporter = DefaultErrorReporter());
+
+ // Build a model based on a pre-loaded flatbuffer. The caller retains
+ // ownership of the buffer and should keep it alive until the returned object
+ // is destroyed. Return a nullptr in case of failure.
+ static std::unique_ptr<FlatBufferModel> BuildFromBuffer(
+ const char* buffer, size_t buffer_size,
+ ErrorReporter* error_reporter = DefaultErrorReporter());
+
+ // Releases memory or unmaps mmaped meory.
+ ~FlatBufferModel();
+
+ // Copying or assignment is disallowed to simplify ownership semantics.
+ FlatBufferModel(const FlatBufferModel&) = delete;
+ FlatBufferModel& operator=(const FlatBufferModel&) = delete;
+
+ bool initialized() const { return model_ != nullptr; }
+ const tflite::Model* operator->() const { return model_; }
+ const tflite::Model* GetModel() const { return model_; }
+ ErrorReporter* error_reporter() const { return error_reporter_; }
+ const Allocation* allocation() const { return allocation_; }
+
+ // Returns true if the model identifier is correct (otherwise false and
+ // reports an error).
+ bool CheckModelIdentifier() const;
+
+ private:
+ // Load a model from `filename`. If `mmap_file` is true then use mmap,
+ // otherwise make a copy of the model in a buffer.
+ //
+ // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
+ // used.
+ explicit FlatBufferModel(
+ const char* filename, bool mmap_file = true,
+ ErrorReporter* error_reporter = DefaultErrorReporter(),
+ bool use_nnapi = false);
+
+ // Load a model from `ptr` and `num_bytes` of the model file. The `ptr` has to
+ // remain alive and unchanged until the end of this flatbuffermodel's
+ // lifetime.
+ //
+ // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
+ // used.
+ FlatBufferModel(const char* ptr, size_t num_bytes,
+ ErrorReporter* error_reporter = DefaultErrorReporter());
+
+ // Flatbuffer traverser pointer. (Model* is a pointer that is within the
+ // allocated memory of the data allocated by allocation's internals.
+ const tflite::Model* model_ = nullptr;
+ ErrorReporter* error_reporter_;
+ Allocation* allocation_ = nullptr;
+};
+
+// Abstract interface that returns TfLiteRegistrations given op codes or custom
+// op names. This is the mechanism that ops being referenced in the flatbuffer
+// model are mapped to executable function pointers (TfLiteRegistrations).
+class OpResolver {
+ public:
+ // Find the op registration for a builtin operator by enum code.
+ virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0;
+ // Find the op registration of a custom operator by op name.
+ virtual TfLiteRegistration* FindOp(const char* op) const = 0;
+ virtual ~OpResolver() {}
+};
+
+// Build an interpreter capable of interpreting `model`.
+//
+// model: a scoped model whose lifetime must be at least as long as
+// the interpreter. In principle multiple interpreters can be made from
+// a single model.
+// op_resolver: An instance that implements the Resolver interface which maps
+// custom op names and builtin op codes to op registrations.
+// reportError: a functor that is called to report errors that handles
+// printf var arg semantics. The lifetime of the reportError object must
+// be greater than or equal to the Interpreter created by operator().
+//
+// Returns a kTfLiteOk when successful and sets interpreter to a valid
+// Interpreter. Note: the user must ensure the model lifetime is at least as
+// long as interpreter's lifetime.
+class InterpreterBuilder {
+ public:
+ InterpreterBuilder(const FlatBufferModel& model,
+ const OpResolver& op_resolver);
+ // Build an interpreter given only the raw flatbuffer Model object (instead
+ // of a FlatBufferModel). Mostly used for testing.
+ // If `error_reporter` is null, then DefaultErrorReporter() is used.
+ InterpreterBuilder(const ::tflite::Model* model,
+ const OpResolver& op_resolver,
+ ErrorReporter* error_reporter = DefaultErrorReporter());
+ InterpreterBuilder(const InterpreterBuilder&) = delete;
+ InterpreterBuilder& operator=(const InterpreterBuilder&) = delete;
+ TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter);
+
+ private:
+ TfLiteStatus BuildLocalIndexToRegistrationMapping();
+ TfLiteStatus ParseNodes(
+ const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
+ Interpreter* interpreter);
+ TfLiteStatus ParseTensors(
+ const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
+ const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
+ Interpreter* interpreter);
+
+ const ::tflite::Model* model_;
+ const OpResolver& op_resolver_;
+ ErrorReporter* error_reporter_;
+
+ std::vector<TfLiteRegistration*> flatbuffer_op_index_to_registration_;
+ std::vector<BuiltinOperator> flatbuffer_op_index_to_registration_types_;
+ const Allocation* allocation_ = nullptr;
+};
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_
diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc
new file mode 100644
index 0000000000..ae823650d6
--- /dev/null
+++ b/tensorflow/contrib/lite/model_test.cc
@@ -0,0 +1,258 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <fcntl.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "tensorflow/contrib/lite/model.h"
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/error_reporter.h"
+
+// Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object,
+// we must declare this in global namespace, so argument-dependent operator
+// lookup works.
+inline bool operator==(const TfLiteRegistration& a,
+ const TfLiteRegistration& b) {
+ return a.invoke == b.invoke && a.init == b.init && a.prepare == b.prepare &&
+ a.free == b.free;
+}
+
+namespace tflite {
+
+// Provide a dummy operation that does nothing.
+namespace {
+void* dummy_init(TfLiteContext*, const char*, size_t) { return nullptr; }
+void dummy_free(TfLiteContext*, void*) {}
+TfLiteStatus dummy_resize(TfLiteContext*, TfLiteNode*) { return kTfLiteOk; }
+TfLiteStatus dummy_invoke(TfLiteContext*, TfLiteNode*) { return kTfLiteOk; }
+TfLiteRegistration dummy_reg = {dummy_init, dummy_free, dummy_resize,
+ dummy_invoke};
+} // namespace
+
+// Provide a trivial resolver that returns a constant value no matter what
+// op is asked for.
+class TrivialResolver : public OpResolver {
+ public:
+ explicit TrivialResolver(TfLiteRegistration* constant_return = nullptr)
+ : constant_return_(constant_return) {}
+ // Find the op registration of a custom operator by op name.
+ TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override {
+ return constant_return_;
+ }
+ // Find the op registration of a custom operator by op name.
+ TfLiteRegistration* FindOp(const char* op) const override {
+ return constant_return_;
+ }
+
+ private:
+ TfLiteRegistration* constant_return_;
+};
+
+TEST(BasicFlatBufferModel, TestNonExistantFiles) {
+ ASSERT_TRUE(!FlatBufferModel::BuildFromFile("/tmp/tflite_model_1234"));
+}
+
+// Make sure a model with nothing in it loads properly.
+TEST(BasicFlatBufferModel, TestEmptyModelsAndNullDestination) {
+ auto model = FlatBufferModel::BuildFromFile(
+ "tensorflow/contrib/lite/testdata/empty_model.bin");
+ ASSERT_TRUE(model);
+ // Now try to build it into a model.
+ std::unique_ptr<Interpreter> interpreter;
+ ASSERT_EQ(InterpreterBuilder(*model, TrivialResolver())(&interpreter),
+ kTfLiteOk);
+ ASSERT_NE(interpreter, nullptr);
+ ASSERT_NE(InterpreterBuilder(*model, TrivialResolver())(nullptr), kTfLiteOk);
+}
+
+// Make sure currently unsupported # of subgraphs are checked
+// TODO(aselle): Replace this test when multiple subgraphs are supported.
+TEST(BasicFlatBufferModel, TestZeroAndMultipleSubgraphs) {
+ auto m1 = FlatBufferModel::BuildFromFile(
+ "tensorflow/contrib/lite/testdata/0_subgraphs.bin");
+ ASSERT_TRUE(m1);
+ std::unique_ptr<Interpreter> interpreter1;
+ ASSERT_NE(InterpreterBuilder(*m1, TrivialResolver())(&interpreter1),
+ kTfLiteOk);
+
+ auto m2 = FlatBufferModel::BuildFromFile(
+ "tensorflow/contrib/lite/testdata/2_subgraphs.bin");
+ ASSERT_TRUE(m2);
+ std::unique_ptr<Interpreter> interpreter2;
+ ASSERT_NE(InterpreterBuilder(*m2, TrivialResolver())(&interpreter2),
+ kTfLiteOk);
+}
+
+// Test what happens if we cannot bind any of the ops.
+TEST(BasicFlatBufferModel, TestModelWithoutNullRegistrations) {
+ auto model = FlatBufferModel::BuildFromFile(
+ "tensorflow/contrib/lite/testdata/test_model.bin");
+ ASSERT_TRUE(model);
+ // Check that we get an error code and interpreter pointer is reset.
+ std::unique_ptr<Interpreter> interpreter(new Interpreter);
+ ASSERT_NE(InterpreterBuilder(*model, TrivialResolver(nullptr))(&interpreter),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter, nullptr);
+}
+
+// Make sure model is read to interpreter propelrly
+TEST(BasicFlatBufferModel, TestModelInInterpreter) {
+ auto model = FlatBufferModel::BuildFromFile(
+ "tensorflow/contrib/lite/testdata/test_model.bin");
+ ASSERT_TRUE(model);
+ // Check that we get an error code and interpreter pointer is reset.
+ std::unique_ptr<Interpreter> interpreter(new Interpreter);
+ ASSERT_EQ(
+ InterpreterBuilder(*model, TrivialResolver(&dummy_reg))(&interpreter),
+ kTfLiteOk);
+ ASSERT_NE(interpreter, nullptr);
+ ASSERT_EQ(interpreter->tensors_size(), 4);
+ ASSERT_EQ(interpreter->nodes_size(), 2);
+ std::vector<int> inputs = {0, 1};
+ std::vector<int> outputs = {2, 3};
+ ASSERT_EQ(interpreter->inputs(), inputs);
+ ASSERT_EQ(interpreter->outputs(), outputs);
+
+ EXPECT_EQ(std::string(interpreter->GetInputName(0)), "input0");
+ EXPECT_EQ(std::string(interpreter->GetInputName(1)), "input1");
+ EXPECT_EQ(std::string(interpreter->GetOutputName(0)), "out1");
+ EXPECT_EQ(std::string(interpreter->GetOutputName(1)), "out2");
+
+ // Make sure all input tensors are correct
+ TfLiteTensor* i0 = interpreter->tensor(0);
+ ASSERT_EQ(i0->type, kTfLiteFloat32);
+ ASSERT_NE(i0->data.raw, nullptr); // mmapped
+ ASSERT_EQ(i0->allocation_type, kTfLiteMmapRo);
+ TfLiteTensor* i1 = interpreter->tensor(1);
+ ASSERT_EQ(i1->type, kTfLiteFloat32);
+ ASSERT_EQ(i1->data.raw, nullptr);
+ ASSERT_EQ(i1->allocation_type, kTfLiteArenaRw);
+ TfLiteTensor* o0 = interpreter->tensor(2);
+ ASSERT_EQ(o0->type, kTfLiteFloat32);
+ ASSERT_EQ(o0->data.raw, nullptr);
+ ASSERT_EQ(o0->allocation_type, kTfLiteArenaRw);
+ TfLiteTensor* o1 = interpreter->tensor(3);
+ ASSERT_EQ(o1->type, kTfLiteFloat32);
+ ASSERT_EQ(o1->data.raw, nullptr);
+ ASSERT_EQ(o1->allocation_type, kTfLiteArenaRw);
+
+ // Check op 0 which has inputs {0, 1} outputs {2}.
+ {
+ const std::pair<TfLiteNode, TfLiteRegistration>* node_and_reg0 =
+ interpreter->node_and_registration(0);
+ ASSERT_NE(node_and_reg0, nullptr);
+ const TfLiteNode& node0 = node_and_reg0->first;
+ const TfLiteRegistration& reg0 = node_and_reg0->second;
+ TfLiteIntArray* desired_inputs = TfLiteIntArrayCreate(2);
+ desired_inputs->data[0] = 0;
+ desired_inputs->data[1] = 1;
+ TfLiteIntArray* desired_outputs = TfLiteIntArrayCreate(1);
+ desired_outputs->data[0] = 2;
+ ASSERT_TRUE(TfLiteIntArrayEqual(node0.inputs, desired_inputs));
+ ASSERT_TRUE(TfLiteIntArrayEqual(node0.outputs, desired_outputs));
+ TfLiteIntArrayFree(desired_inputs);
+ TfLiteIntArrayFree(desired_outputs);
+ ASSERT_EQ(reg0, dummy_reg);
+ }
+
+ // Check op 1 which has inputs {2} outputs {3}.
+ {
+ const std::pair<TfLiteNode, TfLiteRegistration>* node_and_reg1 =
+ interpreter->node_and_registration(1);
+ ASSERT_NE(node_and_reg1, nullptr);
+ const TfLiteNode& node1 = node_and_reg1->first;
+ const TfLiteRegistration& reg1 = node_and_reg1->second;
+ TfLiteIntArray* desired_inputs = TfLiteIntArrayCreate(1);
+ TfLiteIntArray* desired_outputs = TfLiteIntArrayCreate(1);
+ desired_inputs->data[0] = 2;
+ desired_outputs->data[0] = 3;
+ ASSERT_TRUE(TfLiteIntArrayEqual(node1.inputs, desired_inputs));
+ ASSERT_TRUE(TfLiteIntArrayEqual(node1.outputs, desired_outputs));
+ TfLiteIntArrayFree(desired_inputs);
+ TfLiteIntArrayFree(desired_outputs);
+ ASSERT_EQ(reg1, dummy_reg);
+ }
+}
+
+// This tests on a flatbuffer that defines a shape of 2 to be a memory mapped
+// buffer. But the buffer is provided to be only 1 element.
+TEST(BasicFlatBufferModel, TestBrokenMmap) {
+ ASSERT_FALSE(FlatBufferModel::BuildFromFile(
+ "tensorflow/contrib/lite/testdata/test_model_broken.bin"));
+}
+
+TEST(BasicFlatBufferModel, TestNullModel) {
+ // Check that we get an error code and interpreter pointer is reset.
+ std::unique_ptr<Interpreter> interpreter(new Interpreter);
+ ASSERT_NE(
+ InterpreterBuilder(nullptr, TrivialResolver(&dummy_reg))(&interpreter),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.get(), nullptr);
+}
+
+struct TestErrorReporter : public ErrorReporter {
+ int Report(const char* format, va_list args) override {
+ calls++;
+ return 0;
+ }
+ int calls = 0;
+};
+
+// This makes sure the ErrorReporter is marshalled from FlatBufferModel to
+// the Interpreter.
+TEST(BasicFlatBufferModel, TestCustomErrorReporter) {
+ TestErrorReporter reporter;
+ auto model = FlatBufferModel::BuildFromFile(
+ "tensorflow/contrib/lite/testdata/empty_model.bin",
+ &reporter);
+ ASSERT_TRUE(model);
+
+ std::unique_ptr<Interpreter> interpreter;
+ TrivialResolver resolver;
+ InterpreterBuilder(*model, resolver)(&interpreter);
+ ASSERT_NE(interpreter->Invoke(), kTfLiteOk);
+ ASSERT_EQ(reporter.calls, 1);
+}
+
+// This makes sure the ErrorReporter is marshalled from FlatBufferModel to
+// the Interpreter.
+TEST(BasicFlatBufferModel, TestNullErrorReporter) {
+ auto model = FlatBufferModel::BuildFromFile(
+ "tensorflow/contrib/lite/testdata/empty_model.bin", nullptr);
+ ASSERT_TRUE(model);
+
+ std::unique_ptr<Interpreter> interpreter;
+ TrivialResolver resolver;
+ InterpreterBuilder(*model, resolver)(&interpreter);
+ ASSERT_NE(interpreter->Invoke(), kTfLiteOk);
+}
+
+// TODO(aselle): Add tests for serialization of builtin op data types.
+// These tests will occur with the evaluation tests of individual operators,
+// not here.
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/models/smartreply/BUILD b/tensorflow/contrib/lite/models/smartreply/BUILD
new file mode 100644
index 0000000000..fbdf19f205
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/BUILD
@@ -0,0 +1,15 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc
new file mode 100644
index 0000000000..1c422b659a
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc
@@ -0,0 +1,119 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Convert a list of strings to integers via hashing.
+// Input:
+// Input[0]: A list of ngrams. string[num of input]
+//
+// Output:
+// Output[0]: Hashed features. int32[num of input]
+// Output[1]: Weights. float[num of input]
+
+#include <algorithm>
+#include <map>
+#include "re2/re2.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/string_util.h"
+#include <farmhash.h>
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+namespace extract {
+
+static const int kMaxDimension = 1000000;
+static const std::vector<string> kBlacklistNgram = {"<S>", "<E>", "<S> <E>"};
+
+bool Equals(const string& x, const tflite::StringRef& strref) {
+ if (strref.len != x.length()) {
+ return false;
+ }
+ if (strref.len > 0) {
+ int r = memcmp(strref.str, x.data(), strref.len);
+ return r == 0;
+ }
+ return true;
+}
+
+bool IsValidNgram(const tflite::StringRef& strref) {
+ for (const auto& s : kBlacklistNgram) {
+ if (Equals(s, strref)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteIntArray* outputSize1 = TfLiteIntArrayCreate(1);
+ TfLiteIntArray* outputSize2 = TfLiteIntArrayCreate(1);
+ TfLiteTensor* input = GetInput(context, node, 0);
+ int dim = input->dims->data[0];
+ if (dim == 0) {
+ // TFLite non-string output should have size greater than 0.
+ dim = 1;
+ }
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteString);
+ outputSize1->data[0] = dim;
+ outputSize2->data[0] = dim;
+ context->ResizeTensor(context, GetOutput(context, node, 0), outputSize1);
+ context->ResizeTensor(context, GetOutput(context, node, 1), outputSize2);
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input = GetInput(context, node, 0);
+ int num_strings = tflite::GetStringCount(input);
+ TfLiteTensor* label = GetOutput(context, node, 0);
+ TfLiteTensor* weight = GetOutput(context, node, 1);
+
+ std::map<int64, int> feature_id_counts;
+ for (int i = 0; i < num_strings; i++) {
+ // Use fingerprint of feature name as id.
+ auto strref = tflite::GetString(input, i);
+ if (!IsValidNgram(strref)) {
+ label->data.i32[i] = 0;
+ weight->data.i32[i] = 0;
+ continue;
+ }
+
+ int64 feature_id =
+ ::util::Fingerprint64(strref.str, strref.len) % kMaxDimension;
+
+ label->data.i32[i] = static_cast<int32>(feature_id);
+ weight->data.f[i] =
+ std::count(strref.str, strref.str + strref.len, ' ') + 1;
+ }
+ // Explicitly set an empty result to make preceding ops run.
+ if (num_strings == 0) {
+ label->data.i32[0] = 0;
+ weight->data.i32[0] = 0;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace extract
+
+TfLiteRegistration* Register_EXTRACT_FEATURES() {
+ static TfLiteRegistration r = {nullptr, nullptr, extract::Prepare,
+ extract::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc
new file mode 100644
index 0000000000..9b8676bab6
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc
@@ -0,0 +1,100 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+#include <farmhash.h>
+
+namespace tflite {
+
+namespace ops {
+namespace custom {
+TfLiteRegistration* Register_EXTRACT_FEATURES();
+
+namespace {
+
+using ::testing::ElementsAre;
+
+class ExtractFeatureOpModel : public SingleOpModel {
+ public:
+ explicit ExtractFeatureOpModel(const std::vector<string>& input) {
+ input_ = AddInput(TensorType_STRING);
+ signature_ = AddOutput(TensorType_INT32);
+ weight_ = AddOutput(TensorType_FLOAT32);
+
+ SetCustomOp("ExtractFeatures", {}, Register_EXTRACT_FEATURES);
+ BuildInterpreter({{static_cast<int>(input.size())}});
+ PopulateStringTensor(input_, input);
+ }
+
+ std::vector<int> GetSignature() { return ExtractVector<int>(signature_); }
+ std::vector<float> GetWeight() { return ExtractVector<float>(weight_); }
+
+ private:
+ int input_;
+ int signature_;
+ int weight_;
+};
+
+int CalcFeature(const string& str) {
+ return ::util::Fingerprint64(str) % 1000000;
+}
+
+TEST(ExtractFeatureOpTest, RegularInput) {
+ ExtractFeatureOpModel m({"<S>", "<S> Hi", "Hi", "Hi !", "!", "! <E>", "<E>"});
+ m.Invoke();
+ EXPECT_THAT(m.GetSignature(),
+ ElementsAre(0, CalcFeature("<S> Hi"), CalcFeature("Hi"),
+ CalcFeature("Hi !"), CalcFeature("!"),
+ CalcFeature("! <E>"), 0));
+ EXPECT_THAT(m.GetWeight(), ElementsAre(0, 2, 1, 2, 1, 2, 0));
+}
+
+TEST(ExtractFeatureOpTest, OneInput) {
+ ExtractFeatureOpModel m({"Hi"});
+ m.Invoke();
+ EXPECT_THAT(m.GetSignature(), ElementsAre(CalcFeature("Hi")));
+ EXPECT_THAT(m.GetWeight(), ElementsAre(1));
+}
+
+TEST(ExtractFeatureOpTest, ZeroInput) {
+ ExtractFeatureOpModel m({});
+ m.Invoke();
+ EXPECT_THAT(m.GetSignature(), ElementsAre(0));
+ EXPECT_THAT(m.GetWeight(), ElementsAre(0));
+}
+
+TEST(ExtractFeatureOpTest, AllBlacklistInput) {
+ ExtractFeatureOpModel m({"<S>", "<E>"});
+ m.Invoke();
+ EXPECT_THAT(m.GetSignature(), ElementsAre(0, 0));
+ EXPECT_THAT(m.GetWeight(), ElementsAre(0, 0));
+}
+
+} // namespace
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc b/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc
new file mode 100644
index 0000000000..d0dc2a35a7
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc
@@ -0,0 +1,105 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Normalize the string input.
+//
+// Input:
+// Input[0]: One sentence. string[1]
+//
+// Output:
+// Output[0]: Normalized sentence. string[1]
+//
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/strip.h"
+#include "re2/re2.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+namespace normalize {
+
+// Predictor transforms.
+const char kPunctuationsRegex[] = "[.*()\"]";
+
+const std::map<string, string>* kRegexTransforms =
+ new std::map<string, string>({
+ {"([^\\s]+)n't", "\\1 not"},
+ {"([^\\s]+)'nt", "\\1 not"},
+ {"([^\\s]+)'ll", "\\1 will"},
+ {"([^\\s]+)'re", "\\1 are"},
+ {"([^\\s]+)'ve", "\\1 have"},
+ {"i'm", "i am"},
+ });
+
+static const char kStartToken[] = "<S>";
+static const char kEndToken[] = "<E>";
+static const int32 kMaxInputChars = 300;
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ tflite::StringRef input = tflite::GetString(GetInput(context, node, 0), 0);
+
+ string result(absl::AsciiStrToLower(absl::string_view(input.str, input.len)));
+ absl::StripAsciiWhitespace(&result);
+ // Do not remove commas, semi-colons or colons from the sentences as they can
+ // indicate the beginning of a new clause.
+ RE2::GlobalReplace(&result, kPunctuationsRegex, "");
+ RE2::GlobalReplace(&result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)([\\s,;:/])",
+ "\\1\\2");
+ RE2::GlobalReplace(&result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)$", "\\1");
+ for (auto iter = kRegexTransforms->begin(); iter != kRegexTransforms->end();
+ iter++) {
+ RE2::GlobalReplace(&result, iter->first, iter->second);
+ }
+
+ // Treat questions & interjections as special cases.
+ RE2::GlobalReplace(&result, "([?])+", "\\1");
+ RE2::GlobalReplace(&result, "([!])+", "\\1");
+ RE2::GlobalReplace(&result, "([^?!]+)([?!])", "\\1 \\2 ");
+ RE2::GlobalReplace(&result, "([?!])([?!])", "\\1 \\2");
+
+ RE2::GlobalReplace(&result, "[\\s,:;\\-&'\"]+$", "");
+ RE2::GlobalReplace(&result, "^[\\s,:;\\-&'\"]+", "");
+ absl::StripAsciiWhitespace(&result);
+
+ // Add start and end token.
+ // Truncate input to maximum allowed size.
+ if (result.length() <= kMaxInputChars) {
+ absl::StrAppend(&result, " ", kEndToken);
+ } else {
+ result = result.substr(0, kMaxInputChars);
+ }
+ result = absl::StrCat(kStartToken, " ", result);
+
+ tflite::DynamicBuffer buf;
+ buf.AddString(result.data(), result.length());
+ buf.WriteToTensor(GetOutput(context, node, 0));
+ return kTfLiteOk;
+}
+
+} // namespace normalize
+
+TfLiteRegistration* Register_NORMALIZE() {
+ static TfLiteRegistration r = {nullptr, nullptr, nullptr, normalize::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc b/tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc
new file mode 100644
index 0000000000..4d35dba9a6
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc
@@ -0,0 +1,90 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace tflite {
+
+namespace ops {
+namespace custom {
+TfLiteRegistration* Register_NORMALIZE();
+
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class NormalizeOpModel : public SingleOpModel {
+ public:
+ explicit NormalizeOpModel(const string& input) {
+ input_ = AddInput(TensorType_STRING);
+ output_ = AddOutput(TensorType_STRING);
+
+ SetCustomOp("Normalize", {}, Register_NORMALIZE);
+ BuildInterpreter({{static_cast<int>(input.size())}});
+ PopulateStringTensor(input_, {input});
+ }
+
+ std::vector<string> GetStringOutput() {
+ TfLiteTensor* output = interpreter_->tensor(output_);
+ int num = GetStringCount(output);
+ std::vector<string> result(num);
+ for (int i = 0; i < num; i++) {
+ auto ref = GetString(output, i);
+ result[i] = string(ref.str, ref.len);
+ }
+ return result;
+ }
+
+ private:
+ int input_;
+ int output_;
+};
+
+TEST(NormalizeOpTest, RegularInput) {
+ NormalizeOpModel m("I'm good; you're welcome");
+ m.Invoke();
+ EXPECT_THAT(m.GetStringOutput(),
+ ElementsAreArray({"<S> i am good; you are welcome <E>"}));
+}
+
+TEST(NormalizeOpTest, OneInput) {
+ NormalizeOpModel m("Hi!!!!");
+ m.Invoke();
+ EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"<S> hi ! <E>"}));
+}
+
+TEST(NormalizeOpTest, EmptyInput) {
+ NormalizeOpModel m("");
+ m.Invoke();
+ EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"<S> <E>"}));
+}
+
+} // namespace
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/models/smartreply/ops/predict.cc b/tensorflow/contrib/lite/models/smartreply/ops/predict.cc
new file mode 100644
index 0000000000..7b23adb990
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/ops/predict.cc
@@ -0,0 +1,174 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Lookup projected hash signatures in Predictor model,
+// output predicted labels and weights in decreasing order.
+//
+// Input:
+// Input[0]: A list of hash signatures. int32[num of input]
+// Input[1]: Hash signature keys in the model. int32[keys of model]
+// Input[2]: Labels in the model. int32[keys of model, item per entry]
+// Input[3]: Weights in the model. float[keys of model, item per entry]
+//
+// Output:
+// Output[0]: Predicted labels. int32[num of output]
+// Output[1]: Predicted weights. float[num of output]
+//
+
+#include <algorithm>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/context.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+namespace predict {
+
+struct PredictOption {
+ int32_t num_output;
+ float weight_threshold;
+
+ static PredictOption* Cast(void* ptr) {
+ return reinterpret_cast<PredictOption*>(ptr);
+ }
+};
+
+bool WeightGreater(const std::pair<int32_t, float>& a,
+ const std::pair<int32_t, float>& b) {
+ return a.second > b.second;
+}
+
+void* Init(TfLiteContext* context, const char* custom_option, size_t length) {
+ if (custom_option == nullptr || length != sizeof(PredictOption)) {
+ fprintf(stderr, "No Custom option set\n");
+ exit(1);
+ }
+ PredictOption* option = new PredictOption;
+ int offset = 0;
+ option->num_output =
+ *reinterpret_cast<const int32_t*>(custom_option + offset);
+ offset += sizeof(int32_t);
+ option->weight_threshold =
+ *reinterpret_cast<const float*>(custom_option + offset);
+ return reinterpret_cast<void*>(option);
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete PredictOption::Cast(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+
+ TfLiteTensor* lookup = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* model_key = &context->tensors[node->inputs->data[1]];
+ TfLiteTensor* model_label = &context->tensors[node->inputs->data[2]];
+ TfLiteTensor* model_weight = &context->tensors[node->inputs->data[3]];
+ TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, model_key->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, model_label->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, model_weight->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, lookup->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, model_key->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, model_label->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, model_weight->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, model_key->dims->data[0],
+ model_label->dims->data[0]);
+ TF_LITE_ENSURE_EQ(context, model_key->dims->data[0],
+ model_weight->dims->data[0]);
+ TF_LITE_ENSURE_EQ(context, model_label->dims->data[1],
+ model_weight->dims->data[1]);
+
+ PredictOption* option = PredictOption::Cast(node->user_data);
+ TfLiteTensor* output_label = &context->tensors[node->outputs->data[0]];
+ TfLiteTensor* output_weight = &context->tensors[node->outputs->data[1]];
+ TF_LITE_ENSURE_EQ(context, output_label->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, output_weight->type, kTfLiteFloat32);
+
+ TfLiteIntArray* label_size = TfLiteIntArrayCreate(1);
+ label_size->data[0] = option->num_output;
+ TfLiteIntArray* weight_size = TfLiteIntArrayCreate(1);
+ weight_size->data[0] = option->num_output;
+ TfLiteStatus status =
+ context->ResizeTensor(context, output_label, label_size);
+ if (status != kTfLiteOk) {
+ return status;
+ }
+ return context->ResizeTensor(context, output_weight, weight_size);
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* lookup = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* model_key = &context->tensors[node->inputs->data[1]];
+ TfLiteTensor* model_label = &context->tensors[node->inputs->data[2]];
+ TfLiteTensor* model_weight = &context->tensors[node->inputs->data[3]];
+
+ // Aggregate by key
+ std::unordered_map<int32_t, float> aggregation;
+ const int num_input = lookup->dims->data[0];
+ const int num_rows = model_key->dims->data[0];
+ const int items = model_label->dims->data[1];
+ int* model_key_end = model_key->data.i32 + num_rows;
+
+ for (int i = 0; i < num_input; i++) {
+ int* ptr = std::lower_bound(model_key->data.i32, model_key_end,
+ lookup->data.i32[i]);
+ if (ptr != nullptr && ptr != model_key_end && *ptr == lookup->data.i32[i]) {
+ int idx = ptr - model_key->data.i32;
+ for (int j = 0; j < items; j++) {
+ aggregation[model_label->data.i32[idx * items + j]] +=
+ model_weight->data.f[idx * items + j] / num_input;
+ }
+ }
+ }
+
+ // Sort by value
+ std::vector<std::pair<int32_t, float>> sorted_labels(aggregation.begin(),
+ aggregation.end());
+ std::sort(sorted_labels.begin(), sorted_labels.end(), WeightGreater);
+
+ PredictOption* option = PredictOption::Cast(node->user_data);
+ TfLiteTensor* output_label = &context->tensors[node->outputs->data[0]];
+ TfLiteTensor* output_weight = &context->tensors[node->outputs->data[1]];
+ for (int i = 0; i < output_label->dims->data[0]; i++) {
+ if (i >= sorted_labels.size() ||
+ sorted_labels[i].second < option->weight_threshold) {
+ // Set -1 to avoid lookup message with id 0, which is set for backoff.
+ output_label->data.i32[i] = -1;
+ output_weight->data.f[i] = 0.0f;
+ } else {
+ output_label->data.i32[i] = sorted_labels[i].first;
+ output_weight->data.f[i] = sorted_labels[i].second;
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace predict
+
+TfLiteRegistration* Register_PREDICT() {
+ static TfLiteRegistration r = {predict::Init, predict::Free, predict::Prepare,
+ predict::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc b/tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc
new file mode 100644
index 0000000000..e97c58cbd1
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc
@@ -0,0 +1,183 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace tflite {
+
+namespace ops {
+namespace custom {
+TfLiteRegistration* Register_PREDICT();
+
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class PredictOpModel : public SingleOpModel {
+ public:
+ PredictOpModel(std::initializer_list<int> input_signature_shape,
+ std::initializer_list<int> key_shape,
+ std::initializer_list<int> labelweight_shape, int num_output,
+ float threshold) {
+ input_signature_ = AddInput(TensorType_INT32);
+ model_key_ = AddInput(TensorType_INT32);
+ model_label_ = AddInput(TensorType_INT32);
+ model_weight_ = AddInput(TensorType_FLOAT32);
+ output_label_ = AddOutput(TensorType_INT32);
+ output_weight_ = AddOutput(TensorType_FLOAT32);
+
+ std::vector<uint8_t> predict_option;
+ writeInt32(num_output, &predict_option);
+ writeFloat32(threshold, &predict_option);
+ SetCustomOp("Predict", predict_option, Register_PREDICT);
+ BuildInterpreter({{input_signature_shape, key_shape, labelweight_shape,
+ labelweight_shape}});
+ }
+
+ void SetInputSignature(std::initializer_list<int> data) {
+ PopulateTensor<int>(input_signature_, data);
+ }
+
+ void SetModelKey(std::initializer_list<int> data) {
+ PopulateTensor<int>(model_key_, data);
+ }
+
+ void SetModelLabel(std::initializer_list<int> data) {
+ PopulateTensor<int>(model_label_, data);
+ }
+
+ void SetModelWeight(std::initializer_list<float> data) {
+ PopulateTensor<float>(model_weight_, data);
+ }
+
+ std::vector<int> GetLabel() { return ExtractVector<int>(output_label_); }
+ std::vector<float> GetWeight() {
+ return ExtractVector<float>(output_weight_);
+ }
+
+ void writeFloat32(float value, std::vector<uint8_t>* data) {
+ union {
+ float v;
+ uint8_t r[4];
+ } float_to_raw;
+ float_to_raw.v = value;
+ for (unsigned char i : float_to_raw.r) {
+ data->push_back(i);
+ }
+ }
+
+ void writeInt32(int32_t value, std::vector<uint8_t>* data) {
+ union {
+ int32_t v;
+ uint8_t r[4];
+ } int32_to_raw;
+ int32_to_raw.v = value;
+ for (unsigned char i : int32_to_raw.r) {
+ data->push_back(i);
+ }
+ }
+
+ private:
+ int input_signature_;
+ int model_key_;
+ int model_label_;
+ int model_weight_;
+ int output_label_;
+ int output_weight_;
+};
+
+TEST(PredictOpTest, AllLabelsAreValid) {
+ PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001);
+ m.SetInputSignature({1, 3, 7, 9});
+ m.SetModelKey({1, 2, 4, 6, 7});
+ m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
+ m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
+ m.Invoke();
+ EXPECT_THAT(m.GetLabel(), ElementsAreArray({12, 11}));
+ EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1, 0.05})));
+}
+
+TEST(PredictOpTest, MoreLabelsThanRequired) {
+ PredictOpModel m({4}, {5}, {5, 2}, 1, 0.0001);
+ m.SetInputSignature({1, 3, 7, 9});
+ m.SetModelKey({1, 2, 4, 6, 7});
+ m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
+ m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
+ m.Invoke();
+ EXPECT_THAT(m.GetLabel(), ElementsAreArray({12}));
+ EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1})));
+}
+
+TEST(PredictOpTest, OneLabelDoesNotPassThreshold) {
+ PredictOpModel m({4}, {5}, {5, 2}, 2, 0.07);
+ m.SetInputSignature({1, 3, 7, 9});
+ m.SetModelKey({1, 2, 4, 6, 7});
+ m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
+ m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
+ m.Invoke();
+ EXPECT_THAT(m.GetLabel(), ElementsAreArray({12, -1}));
+ EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1, 0})));
+}
+
+TEST(PredictOpTest, NoneLabelPassThreshold) {
+ PredictOpModel m({4}, {5}, {5, 2}, 2, 0.6);
+ m.SetInputSignature({1, 3, 7, 9});
+ m.SetModelKey({1, 2, 4, 6, 7});
+ m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
+ m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
+ m.Invoke();
+ EXPECT_THAT(m.GetLabel(), ElementsAreArray({-1, -1}));
+ EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0, 0})));
+}
+
+TEST(PredictOpTest, OnlyOneLabelGenerated) {
+ PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001);
+ m.SetInputSignature({1, 3, 7, 9});
+ m.SetModelKey({1, 2, 4, 6, 7});
+ m.SetModelLabel({11, 0, 11, 0, 11, 0, 11, 0, 11, 0});
+ m.SetModelWeight({0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0});
+ m.Invoke();
+ EXPECT_THAT(m.GetLabel(), ElementsAreArray({11, -1}));
+ EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.05, 0})));
+}
+
+TEST(PredictOpTest, NoLabelGenerated) {
+ PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001);
+ m.SetInputSignature({5, 3, 7, 9});
+ m.SetModelKey({1, 2, 4, 6, 7});
+ m.SetModelLabel({11, 0, 11, 0, 11, 0, 11, 0, 0, 0});
+ m.SetModelWeight({0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0, 0});
+ m.Invoke();
+ EXPECT_THAT(m.GetLabel(), ElementsAreArray({-1, -1}));
+ EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0, 0})));
+}
+
+} // namespace
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.cc b/tensorflow/contrib/lite/models/smartreply/predictor.cc
new file mode 100644
index 0000000000..a28222213e
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/predictor.cc
@@ -0,0 +1,116 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/models/smartreply/predictor.h"
+
+#include "absl/strings/str_split.h"
+#include "re2/re2.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/string_util.h"
+#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"
+
+void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
+
+namespace tflite {
+namespace custom {
+namespace smartreply {
+
+// Split sentence into segments (using punctuation).
+std::vector<string> SplitSentence(const string& input) {
+ string result(input);
+
+ RE2::GlobalReplace(&result, "([?.!,])+", " \\1");
+ RE2::GlobalReplace(&result, "([?.!,])+\\s+", "\\1\t");
+ RE2::GlobalReplace(&result, "[ ]+", " ");
+ RE2::GlobalReplace(&result, "\t+$", "");
+
+ return strings::Split(result, '\t');
+}
+
+// Predict with TfLite model.
+void ExecuteTfLite(const string& sentence, ::tflite::Interpreter* interpreter,
+ std::map<string, float>* response_map) {
+ {
+ TfLiteTensor* input = interpreter->tensor(interpreter->inputs()[0]);
+ tflite::DynamicBuffer buf;
+ buf.AddString(sentence.data(), sentence.length());
+ buf.WriteToTensor(input);
+ interpreter->AllocateTensors();
+
+ interpreter->Invoke();
+
+ TfLiteTensor* messages = interpreter->tensor(interpreter->outputs()[0]);
+ TfLiteTensor* confidence = interpreter->tensor(interpreter->outputs()[1]);
+
+ for (int i = 0; i < confidence->dims->data[0]; i++) {
+ float weight = confidence->data.f[i];
+ auto response_text = tflite::GetString(messages, i);
+ if (response_text.len > 0) {
+ (*response_map)[string(response_text.str, response_text.len)] += weight;
+ }
+ }
+ }
+}
+
+void GetSegmentPredictions(
+ const std::vector<string>& input, const ::tflite::FlatBufferModel& model,
+ const SmartReplyConfig& config,
+ std::vector<PredictorResponse>* predictor_responses) {
+ // Initialize interpreter
+ std::unique_ptr<::tflite::Interpreter> interpreter;
+ ::tflite::MutableOpResolver resolver;
+ RegisterSelectedOps(&resolver);
+ ::tflite::InterpreterBuilder(model, resolver)(&interpreter);
+
+ if (!model.initialized()) {
+ fprintf(stderr, "Failed to mmap model \n");
+ return;
+ }
+
+ // Execute Tflite Model
+ std::map<string, float> response_map;
+ std::vector<string> sentences;
+ for (const string& str : input) {
+ std::vector<string> splitted_str = SplitSentence(str);
+ sentences.insert(sentences.end(), splitted_str.begin(), splitted_str.end());
+ }
+ for (const auto& sentence : sentences) {
+ ExecuteTfLite(sentence, interpreter.get(), &response_map);
+ }
+
+ // Generate the result.
+ for (const auto& iter : response_map) {
+ PredictorResponse prediction(iter.first, iter.second);
+ predictor_responses->emplace_back(prediction);
+ }
+ std::sort(predictor_responses->begin(), predictor_responses->end(),
+ [](const PredictorResponse& a, const PredictorResponse& b) {
+ return a.GetScore() > b.GetScore();
+ });
+
+ // Add backoff response.
+ for (const string& backoff : config.backoff_responses) {
+ if (predictor_responses->size() >= config.num_response) {
+ break;
+ }
+ predictor_responses->push_back({backoff, config.backoff_confidence});
+ }
+}
+
+} // namespace smartreply
+} // namespace custom
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.h b/tensorflow/contrib/lite/models/smartreply/predictor.h
new file mode 100644
index 0000000000..3b9a2b32e1
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/predictor.h
@@ -0,0 +1,80 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace custom {
+namespace smartreply {
+
+const int kDefaultNumResponse = 10;
+const float kDefaultBackoffConfidence = 1e-4;
+
+class PredictorResponse;
+struct SmartReplyConfig;
+
+// With a given string as input, predict the response with a Tflite model.
+// When config.backoff_response is not empty, predictor_responses will be filled
+// with messagees from backoff response.
+void GetSegmentPredictions(const std::vector<string>& input,
+ const ::tflite::FlatBufferModel& model,
+ const SmartReplyConfig& config,
+ std::vector<PredictorResponse>* predictor_responses);
+
+// Data object used to hold a single predictor response.
+// It includes messages, and confidence.
+class PredictorResponse {
+ public:
+ PredictorResponse(const string& response_text, float score) {
+ response_text_ = response_text;
+ prediction_score_ = score;
+ }
+
+ // Accessor methods.
+ const string& GetText() const { return response_text_; }
+ float GetScore() const { return prediction_score_; }
+
+ private:
+ string response_text_ = "";
+ float prediction_score_ = 0.0;
+};
+
+// Configurations for SmartReply.
+struct SmartReplyConfig {
+ // Maximum responses to return.
+ int num_response;
+ // Default confidence for backoff responses.
+ float backoff_confidence;
+ // Backoff responses are used when predicted responses cannot fulfill the
+ // list.
+ const std::vector<string>& backoff_responses;
+
+ SmartReplyConfig(std::vector<string> backoff_responses)
+ : num_response(kDefaultNumResponse),
+ backoff_confidence(kDefaultBackoffConfidence),
+ backoff_responses(backoff_responses) {}
+};
+
+} // namespace smartreply
+} // namespace custom
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
diff --git a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc
new file mode 100644
index 0000000000..2fa9923bc9
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc
@@ -0,0 +1,150 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/models/smartreply/predictor.h"
+
+#include <fstream>
+#include <unordered_set>
+
+#include "base/logging.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
+#include "tensorflow/contrib/lite/models/test_utils.h"
+
+namespace tflite {
+namespace custom {
+namespace smartreply {
+namespace {
+
+const char kModelName[] = "smartreply_ondevice_model.bin";
+const char kSamples[] = "smartreply_samples.tsv";
+
+MATCHER_P(IncludeAnyResponesIn, expected_response, "contains the response") {
+ bool has_expected_response = false;
+ for (const auto &item : *arg) {
+ const string &response = item.GetText();
+ if (expected_response.find(response) != expected_response.end()) {
+ has_expected_response = true;
+ break;
+ }
+ }
+ return has_expected_response;
+}
+
+class PredictorTest : public ::testing::Test {
+ protected:
+ PredictorTest() {
+ model_ = tflite::FlatBufferModel::BuildFromFile(
+ StrCat(TestDataPath(), "/", kModelName).c_str());
+ CHECK(model_);
+ }
+ ~PredictorTest() override {}
+
+ std::unique_ptr<::tflite::FlatBufferModel> model_;
+};
+
+TEST_F(PredictorTest, GetSegmentPredictions) {
+ std::vector<PredictorResponse> predictions;
+
+ GetSegmentPredictions({"Welcome"}, *model_, /*config=*/{{}}, &predictions);
+ EXPECT_GT(predictions.size(), 0);
+
+ float max = 0;
+ for (const auto &item : predictions) {
+ LOG(INFO) << "Response: " << item.GetText();
+ if (item.GetScore() > max) {
+ max = item.GetScore();
+ }
+ }
+
+ EXPECT_GT(max, 0.3);
+ EXPECT_THAT(
+ &predictions,
+ IncludeAnyResponesIn(std::unordered_set<string>({"Thanks very much"})));
+}
+
+TEST_F(PredictorTest, TestTwoSentences) {
+ std::vector<PredictorResponse> predictions;
+
+ GetSegmentPredictions({"Hello", "How are you?"}, *model_, /*config=*/{{}},
+ &predictions);
+ EXPECT_GT(predictions.size(), 0);
+
+ float max = 0;
+ for (const auto &item : predictions) {
+ LOG(INFO) << "Response: " << item.GetText();
+ if (item.GetScore() > max) {
+ max = item.GetScore();
+ }
+ }
+
+ EXPECT_GT(max, 0.3);
+ EXPECT_THAT(&predictions, IncludeAnyResponesIn(std::unordered_set<string>(
+ {"Hi, how are you doing?"})));
+}
+
+TEST_F(PredictorTest, TestBackoff) {
+ std::vector<PredictorResponse> predictions;
+
+ GetSegmentPredictions({"你好"}, *model_, /*config=*/{{}}, &predictions);
+ EXPECT_EQ(predictions.size(), 0);
+
+ // Backoff responses are returned in order.
+ GetSegmentPredictions({"你好"}, *model_, /*config=*/{{"Yes", "Ok"}},
+ &predictions);
+ EXPECT_EQ(predictions.size(), 2);
+ EXPECT_EQ(predictions[0].GetText(), "Yes");
+ EXPECT_EQ(predictions[1].GetText(), "Ok");
+}
+
+TEST_F(PredictorTest, BatchTest) {
+ int total_items = 0;
+ int total_responses = 0;
+ int total_triggers = 0;
+
+ string line;
+ std::ifstream fin(StrCat(TestDataPath(), "/", kSamples));
+ while (std::getline(fin, line)) {
+ const std::vector<string> &fields = strings::Split(line, '\t');
+ if (fields.empty()) {
+ continue;
+ }
+
+ // Parse sample file and predict
+ const string &msg = fields[0];
+ std::vector<PredictorResponse> predictions;
+ GetSegmentPredictions({msg}, *model_, /*config=*/{{}}, &predictions);
+
+ // Validate response and generate stats.
+ total_items++;
+ total_responses += predictions.size();
+ if (!predictions.empty()) {
+ total_triggers++;
+ }
+ EXPECT_THAT(&predictions, IncludeAnyResponesIn(std::unordered_set<string>(
+ fields.begin() + 1, fields.end())));
+ }
+
+ LOG(INFO) << "Responses: " << total_responses << " / " << total_items;
+ LOG(INFO) << "Triggers: " << total_triggers << " / " << total_items;
+ EXPECT_EQ(total_triggers, total_items);
+}
+
+} // namespace
+} // namespace smartreply
+} // namespace custom
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/speech_hotword_model_test.cc b/tensorflow/contrib/lite/models/speech_hotword_model_test.cc
new file mode 100644
index 0000000000..f5d1f436bc
--- /dev/null
+++ b/tensorflow/contrib/lite/models/speech_hotword_model_test.cc
@@ -0,0 +1,115 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for speech Hotword model using TFLite Ops.
+
+#include <string.h>
+
+#include <memory>
+#include <string>
+
+#include "base/logging.h"
+#include "file/base/path.h"
+#include "testing/base/public/googletest.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/models/test_utils.h"
+
+namespace tflite {
+namespace models {
+
+void RunTest(int model_input_tensor, int svdf_layer_state_tensor,
+ int model_output_tensor, const string& model_name,
+ const string& golden_in_name, const string& golden_out_name) {
+ // Read the model.
+ string tflite_file_path = file::JoinPath(TestDataPath(), model_name);
+ auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str());
+ CHECK(model) << "Failed to read model from file " << tflite_file_path;
+
+ // Initialize the interpreter.
+ ops::builtin::BuiltinOpResolver builtins;
+ std::unique_ptr<Interpreter> interpreter;
+ InterpreterBuilder(*model, builtins)(&interpreter);
+ CHECK(interpreter != nullptr);
+ interpreter->AllocateTensors();
+
+ // Reset the SVDF layer state.
+ memset(interpreter->tensor(svdf_layer_state_tensor)->data.raw, 0,
+ interpreter->tensor(svdf_layer_state_tensor)->bytes);
+
+ // Load the input frames.
+ Frames input_frames;
+ const string input_file_path = file::JoinPath(TestDataPath(), golden_in_name);
+ ReadFrames(input_file_path, &input_frames);
+
+ // Load the golden output results.
+ Frames output_frames;
+ const string output_file_path =
+ file::JoinPath(TestDataPath(), golden_out_name);
+ ReadFrames(output_file_path, &output_frames);
+
+ const int speech_batch_size =
+ interpreter->tensor(model_input_tensor)->dims->data[0];
+ const int speech_input_size =
+ interpreter->tensor(model_input_tensor)->dims->data[1];
+ const int speech_output_size =
+ interpreter->tensor(model_output_tensor)->dims->data[1];
+ const int input_sequence_size =
+ input_frames[0].size() / (speech_input_size * speech_batch_size);
+ float* input_ptr = interpreter->tensor(model_input_tensor)->data.f;
+ float* output_ptr = interpreter->tensor(model_output_tensor)->data.f;
+
+ // The first layer (SVDF) input size is 40 (speech_input_size). Each speech
+ // input frames for this model is 1280 floats, which can be fed to input in a
+ // sequence of size 32 (input_sequence_size).
+ for (int i = 0; i < TestInputSize(input_frames); i++) {
+ int frame_ptr = 0;
+ for (int s = 0; s < input_sequence_size; s++) {
+ for (int k = 0; k < speech_input_size * speech_batch_size; k++) {
+ input_ptr[k] = input_frames[i][frame_ptr++];
+ }
+ interpreter->Invoke();
+ }
+ // After the whole frame (1280 floats) is fed, we can check the output frame
+ // matches with the golden output frame.
+ for (int k = 0; k < speech_output_size; k++) {
+ ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5);
+ }
+ }
+}
+
+TEST(SpeechHotword, OkGoogleTestRank1) {
+ constexpr int kModelInputTensor = 0;
+ constexpr int kSvdfLayerStateTensor = 4;
+ constexpr int kModelOutputTensor = 18;
+
+ RunTest(kModelInputTensor, kSvdfLayerStateTensor, kModelOutputTensor,
+ "speech_hotword_model_rank1.tflite", "speech_hotword_model_in.csv",
+ "speech_hotword_model_out_rank1.csv");
+}
+
+TEST(SpeechHotword, OkGoogleTestRank2) {
+ constexpr int kModelInputTensor = 17;
+ constexpr int kSvdfLayerStateTensor = 1;
+ constexpr int kModelOutputTensor = 18;
+ RunTest(kModelInputTensor, kSvdfLayerStateTensor, kModelOutputTensor,
+ "speech_hotword_model_rank2.tflite", "speech_hotword_model_in.csv",
+ "speech_hotword_model_out_rank2.csv");
+}
+
+} // namespace models
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc b/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc
new file mode 100644
index 0000000000..687cfab0b2
--- /dev/null
+++ b/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc
@@ -0,0 +1,114 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for speech SpeakerId model using TFLite Ops.
+
+#include <string.h>
+
+#include <memory>
+#include <string>
+
+#include "base/logging.h"
+#include "file/base/path.h"
+#include "testing/base/public/googletest.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/models/test_utils.h"
+#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"
+
+void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
+
+namespace tflite {
+namespace models {
+
+constexpr int kModelInputTensor = 0;
+constexpr int kLstmLayer1OutputStateTensor = 19;
+constexpr int kLstmLayer1CellStateTensor = 20;
+constexpr int kLstmLayer2OutputStateTensor = 40;
+constexpr int kLstmLayer2CellStateTensor = 41;
+constexpr int kLstmLayer3OutputStateTensor = 61;
+constexpr int kLstmLayer3CellStateTensor = 62;
+constexpr int kModelOutputTensor = 66;
+
+TEST(SpeechSpeakerId, OkGoogleTest) {
+ // Read the model.
+ string tflite_file_path =
+ file::JoinPath(TestDataPath(), "speech_speakerid_model.tflite");
+ auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str());
+ CHECK(model) << "Failed to read model from file " << tflite_file_path;
+
+ // Initialize the interpreter.
+ ::tflite::MutableOpResolver resolver;
+ RegisterSelectedOps(&resolver);
+ std::unique_ptr<Interpreter> interpreter;
+ InterpreterBuilder(*model, resolver)(&interpreter);
+ CHECK(interpreter != nullptr);
+ interpreter->AllocateTensors();
+
+ // Load the input frames.
+ Frames input_frames;
+ const string input_file_path =
+ file::JoinPath(TestDataPath(), "speech_speakerid_model_in.csv");
+ ReadFrames(input_file_path, &input_frames);
+
+ // Load the golden output results.
+ Frames output_frames;
+ const string output_file_path =
+ file::JoinPath(TestDataPath(), "speech_speakerid_model_out.csv");
+ ReadFrames(output_file_path, &output_frames);
+
+ const int speech_batch_size =
+ interpreter->tensor(kModelInputTensor)->dims->data[0];
+ const int speech_input_size =
+ interpreter->tensor(kModelInputTensor)->dims->data[1];
+ const int speech_output_size =
+ interpreter->tensor(kModelOutputTensor)->dims->data[1];
+
+ float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f;
+ float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f;
+
+ // Clear the LSTM state for layers.
+ memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer1CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer2CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer3CellStateTensor)->bytes);
+ for (int i = 0; i < input_frames.size(); i++) {
+ // Feed the input to model.
+ int frame_ptr = 0;
+ for (int k = 0; k < speech_input_size * speech_batch_size; k++) {
+ input_ptr[k] = input_frames[i][frame_ptr++];
+ }
+ // Run the model.
+ interpreter->Invoke();
+ // Validate the output.
+ for (int k = 0; k < speech_output_size; k++) {
+ ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5);
+ }
+ }
+}
+
+} // namespace models
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc b/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc
new file mode 100644
index 0000000000..30d89a1354
--- /dev/null
+++ b/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc
@@ -0,0 +1,127 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for speech TERSE AM model using TFLite Ops.
+
+#include <string.h>
+
+#include <memory>
+#include <string>
+
+#include "base/logging.h"
+#include "file/base/path.h"
+#include "testing/base/public/googletest.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/models/test_utils.h"
+
+namespace tflite {
+namespace models {
+
+constexpr int kModelInputTensor = 0;
+constexpr int kLstmLayer1OutputStateTensor = 19;
+constexpr int kLstmLayer1CellStateTensor = 20;
+constexpr int kLstmLayer2OutputStateTensor = 40;
+constexpr int kLstmLayer2CellStateTensor = 41;
+constexpr int kLstmLayer3OutputStateTensor = 61;
+constexpr int kLstmLayer3CellStateTensor = 62;
+constexpr int kLstmLayer4OutputStateTensor = 82;
+constexpr int kLstmLayer4CellStateTensor = 83;
+constexpr int kLstmLayer5OutputStateTensor = 103;
+constexpr int kLstmLayer5CellStateTensor = 104;
+constexpr int kModelOutputTensor = 109;
+
+TEST(SpeechTerseAm, RandomIOTest) {
+ // Read the model.
+ string tflite_file_path =
+ file::JoinPath(TestDataPath(), "speech_terse_am_model.tflite");
+ auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str());
+ CHECK(model) << "Failed to mmap model " << tflite_file_path;
+
+ // Initialize the interpreter.
+ ops::builtin::BuiltinOpResolver builtins;
+ std::unique_ptr<Interpreter> interpreter;
+ InterpreterBuilder(*model, builtins)(&interpreter);
+ CHECK(interpreter != nullptr);
+ interpreter->AllocateTensors();
+
+ // Load the input frames.
+ Frames input_frames;
+ const string input_file_path =
+ file::JoinPath(TestDataPath(), "speech_terse_am_model_in.csv");
+ ReadFrames(input_file_path, &input_frames);
+
+ // Load the golden output results.
+ Frames output_frames;
+ const string output_file_path =
+ file::JoinPath(TestDataPath(), "speech_terse_am_model_out.csv");
+ ReadFrames(output_file_path, &output_frames);
+
+ const int speech_batch_size =
+ interpreter->tensor(kModelInputTensor)->dims->data[0];
+ const int speech_input_size =
+ interpreter->tensor(kModelInputTensor)->dims->data[1];
+ const int speech_output_size =
+ interpreter->tensor(kModelOutputTensor)->dims->data[1];
+
+ float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f;
+ float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f;
+
+ // Clear the LSTM state for layers.
+ memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer1CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer2CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer3CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kLstmLayer4OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer4OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer4CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer4CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kLstmLayer5OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer5OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer5CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer5CellStateTensor)->bytes);
+
+
+ for (int i = 0; i < input_frames.size(); i++) {
+ // Feed the input to model.
+ int frame_ptr = 0;
+ for (int k = 0; k < speech_input_size * speech_batch_size; k++) {
+ input_ptr[k] = input_frames[i][frame_ptr++];
+ }
+ // Run the model.
+ interpreter->Invoke();
+ // Validate the output.
+ for (int k = 0; k < speech_output_size; k++) {
+ ASSERT_NEAR(output_ptr[k], output_frames[i][k], 5.2e-4);
+ }
+ }
+}
+
+} // namespace models
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/speech_tts_model_test.cc b/tensorflow/contrib/lite/models/speech_tts_model_test.cc
new file mode 100644
index 0000000000..e6f2673a42
--- /dev/null
+++ b/tensorflow/contrib/lite/models/speech_tts_model_test.cc
@@ -0,0 +1,116 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for speech TTS model using TFLite Ops.
+
+#include <string.h>
+
+#include <memory>
+#include <string>
+
+#include "base/logging.h"
+#include "file/base/path.h"
+#include "testing/base/public/googletest.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/models/test_utils.h"
+
+namespace tflite {
+namespace models {
+
+constexpr int kModelInputTensor = 0;
+constexpr int kLstmLayer1OutputStateTensor = 25;
+constexpr int kLstmLayer1CellStateTensor = 26;
+constexpr int kLstmLayer2OutputStateTensor = 46;
+constexpr int kLstmLayer2CellStateTensor = 47;
+constexpr int kLstmLayer3OutputStateTensor = 67;
+constexpr int kLstmLayer3CellStateTensor = 68;
+constexpr int kRnnLayerHiddenStateTensor = 73;
+constexpr int kModelOutputTensor = 74;
+
+TEST(SpeechTTS, RandomIOTest) {
+ // Read the model.
+ string tflite_file_path =
+ file::JoinPath(TestDataPath(), "speech_tts_model.tflite");
+ auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str());
+ CHECK(model) << "Failed to mmap model " << tflite_file_path;
+
+ // Initialize the interpreter.
+ ops::builtin::BuiltinOpResolver builtins;
+ std::unique_ptr<Interpreter> interpreter;
+ InterpreterBuilder(*model, builtins)(&interpreter);
+ CHECK(interpreter != nullptr);
+ interpreter->AllocateTensors();
+
+ // Load the input frames.
+ Frames input_frames;
+ const string input_file_path =
+ file::JoinPath(TestDataPath(), "speech_tts_model_in.csv");
+ ReadFrames(input_file_path, &input_frames);
+
+ // Load the golden output results.
+ Frames output_frames;
+ const string output_file_path =
+ file::JoinPath(TestDataPath(), "speech_tts_model_out.csv");
+ ReadFrames(output_file_path, &output_frames);
+
+ const int speech_batch_size =
+ interpreter->tensor(kModelInputTensor)->dims->data[0];
+ const int speech_input_size =
+ interpreter->tensor(kModelInputTensor)->dims->data[1];
+ const int speech_output_size =
+ interpreter->tensor(kModelOutputTensor)->dims->data[1];
+
+ float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f;
+ float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f;
+
+ // Clear the LSTM state for layers.
+ memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer1CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer2CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer3CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kRnnLayerHiddenStateTensor)->data.raw, 0,
+ interpreter->tensor(kRnnLayerHiddenStateTensor)->bytes);
+
+ for (int i = 0; i < input_frames.size(); i++) {
+ // Feed the input to model.
+ int frame_ptr = 0;
+ for (int k = 0; k < speech_input_size * speech_batch_size; k++) {
+ input_ptr[k] = input_frames[i][frame_ptr++];
+ }
+ // Run the model.
+ interpreter->Invoke();
+ // Validate the output.
+ for (int k = 0; k < speech_output_size; k++) {
+ ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5);
+ }
+ }
+}
+
+} // namespace models
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/test_utils.h b/tensorflow/contrib/lite/models/test_utils.h
new file mode 100644
index 0000000000..b2596babd0
--- /dev/null
+++ b/tensorflow/contrib/lite/models/test_utils.h
@@ -0,0 +1,84 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_TEST_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_TEST_UTILS_H_
+
+#include <stdlib.h>
+#include <string.h>
+
+#include <fstream>
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace tflite {
+namespace models {
+using Frames = std::vector<std::vector<float>>;
+} // namespace models
+} // namespace tflite
+
+#ifndef __ANDROID__
+#include "file/base/path.h"
+#include "tensorflow/core/platform/test.h"
+
+inline string TestDataPath() {
+ return string(file::JoinPath(tensorflow::testing::TensorFlowSrcRoot(),
+ "contrib/lite/models/testdata/"));
+}
+inline int TestInputSize(const tflite::models::Frames& input_frames) {
+ return input_frames.size();
+}
+#else
+inline string TestDataPath() {
+ return string("third_party/tensorflow/contrib/lite/models/testdata/");
+}
+
+inline int TestInputSize(const tflite::models::Frames& input_frames) {
+ // Android TAP is very slow, we only test the first 20 frames.
+ return 20;
+}
+#endif
+
+namespace tflite {
+namespace models {
+
+// Read float data from a comma-separated file:
+// Each line will be read into a float vector.
+// The return result will be a vector of float vectors.
+void ReadFrames(const string& csv_file_path, Frames* frames) {
+ std::ifstream csv_file(csv_file_path);
+ string line;
+ while (std::getline(csv_file, line, '\n')) {
+ std::vector<float> fields;
+ // Used by strtok_r internaly for successive calls on the same string.
+ char* save_ptr = nullptr;
+
+ // Tokenize the line.
+ char* next_token =
+ strtok_r(const_cast<char*>(line.c_str()), ",", &save_ptr);
+ while (next_token != nullptr) {
+ float f = strtod(next_token, nullptr);
+ fields.push_back(f);
+ next_token = strtok_r(nullptr, ",", &save_ptr);
+ }
+ frames->push_back(fields);
+ }
+ csv_file.close();
+}
+
+} // namespace models
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_TEST_UTILS_H_
diff --git a/tensorflow/contrib/lite/nnapi/BUILD b/tensorflow/contrib/lite/nnapi/BUILD
new file mode 100644
index 0000000000..402f1e949b
--- /dev/null
+++ b/tensorflow/contrib/lite/nnapi/BUILD
@@ -0,0 +1,25 @@
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = [
+ "//visibility:public",
+])
+
+cc_library(
+ name = "nnapi_lib",
+ hdrs = [
+ "NeuralNetworksShim.h",
+ ],
+ linkopts = ["-ldl"],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
new file mode 100644
index 0000000000..5d06165772
--- /dev/null
+++ b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
@@ -0,0 +1,1916 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef NN_API_SHIM_H0
+#define NN_API_SHIM_H0
+
+#include <dlfcn.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+// helpers
+
+#define NNAPI_LOG(format, ...) printf(format "\n", __VA_ARGS__);
+#define LOAD_FUNCTION(name) \
+ static name##_fn fn = reinterpret_cast<name##_fn>(loadFunction(#name));
+#define EXECUTE_FUNCTION(...) \
+ if (fn != nullptr) { \
+ fn(__VA_ARGS__); \
+ }
+#define EXECUTE_FUNCTION_RETURN(...) return fn != nullptr ? fn(__VA_ARGS__) : 0;
+
+inline void* loadLibrary(const char* name) {
+ // TODO: change RTLD_LOCAL? Assumes there can be multiple instances of nn
+ // api RT
+ void* handle = dlopen(name, RTLD_LAZY | RTLD_LOCAL);
+ if (handle == nullptr) {
+ NNAPI_LOG("nnapi error: unable to open library %s", name);
+ }
+ return handle;
+}
+
+inline void* getLibraryHandle() {
+ static void* handle = loadLibrary("libneuralnetworks.so");
+ return handle;
+}
+
+inline void* loadFunction(const char* name) {
+ void* fn = nullptr;
+ if (getLibraryHandle() != nullptr) {
+ fn = dlsym(getLibraryHandle(), name);
+ }
+ if (fn == nullptr) {
+ NNAPI_LOG("nnapi error: unable to open function %s", name);
+ }
+ return fn;
+}
+
+inline bool NNAPIExists() {
+ static bool nnapi_is_available = getLibraryHandle();
+ return nnapi_is_available;
+}
+
+// nn api types
+
+/**
+ * Operand types.
+ *
+ * The type of operands that can be added to a model.
+ *
+ * Although we define many types, most operators accept just a few
+ * types. Most used are ANEURALNETWORKS_TENSOR_FLOAT32,
+ * ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, and ANEURALNETWORKS_INT32.
+ */
+enum {
+ /** The following entries are used to declare scalars. */
+
+ /** A 32 bit floating point scalar value. */
+ ANEURALNETWORKS_FLOAT32 = 0,
+ /** A signed 32 bit integer scalar value. */
+ ANEURALNETWORKS_INT32 = 1,
+ /** An unsigned 32 bit integer scalar value. */
+ ANEURALNETWORKS_UINT32 = 2,
+
+ /** The following entries are used to declare tensors. */
+
+ /** A tensor of 32 bit floating point values. */
+ ANEURALNETWORKS_TENSOR_FLOAT32 = 3,
+ /** A tensor of 32 bit integer values. */
+ ANEURALNETWORKS_TENSOR_INT32 = 4,
+ /** A tensor of 8 bit integers that represent real numbers.
+ *
+ * Attached to this tensor are two numbers that can be used to convert
+ * the 8 bit integer to the real value and vice versa. These two numbers are:
+ * - scale: a 32 bit floating point value
+ * - zero_value: an 32 bit integer
+ *
+ * The formula is:
+ * real_value = (integer_value - zero_value) * scale.
+ */
+ ANEURALNETWORKS_TENSOR_QUANT8_ASYMM = 5,
+};
+
+/**
+ * Operation types.
+ *
+ * The type of operations that can be added to a model.
+ */
+enum {
+ /** Adds two tensors, elment-wise.
+ *
+ * Takes two input tensors of identical type and compatible dimensions. The
+ * output is the sum of both input tensors, optionally modified by an
+ * activation function.
+ *
+ * Two dimensions are compatible when:
+ * 1. they are equal, or
+ * 2. one of them is 1
+ *
+ * The size of the output is the maximum size along each dimension of the
+ * input operands. It starts with the trailing dimensions, and works its way
+ * forward.
+ *
+ * Example:
+ *
+ * input1.dimension = {4, 1, 2}
+ * input2.dimension = {5, 4, 3, 1}
+ * output.dimension = {5, 4, 3, 2}
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ *
+ * Supported tensor rank: up to 4
+ *
+ * Inputs:
+ * * 0: A tensor.
+ * * 1: A tensor of the same type, and compatible dimensions as input0.
+ * * 2: An INT32 value, and has to be one of the {@link FuseCode} values.
+ * Specifies the activation to invoke on the result of each addition.
+ *
+ * Outputs:
+ * * 0: The sum, a tensor of the same type as input0.
+ */
+ ANEURALNETWORKS_ADD = 0,
+ /** Performs a 2-D average pooling operation.
+ *
+ * The output dimensions are functions of the filter dimensions, stride, and
+ * padding.
+ *
+ * The values in the output tensor are computed as:
+ *
+ * output[batch, row, col, channel] =
+ * sum_{i, j}(input[batch, row + i, col + j, channel]) / sum(1)
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: 4, with "NHWC" data layout.
+ *
+ * Inputs:
+ * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the
+ * input.
+ * * 1: An INT32 value, specifying the padding on the left, in the ‘width’
+ * dimension.
+ * * 2: An INT32 value, specifying the padding on the right,in the ‘width’
+ * dimension.
+ * * 3: An INT32 value, specifying the padding on the top, in the ‘height’
+ * dimension.
+ * * 4: An INT32 value, specifying the padding on the bottom, in the ‘height’
+ * dimension.
+ * * 5: An INT32 value, specifying the output stride in the ‘width’ dimension.
+ * * 6: An INT32 value, specifying the output stride in the ‘height’
+ * dimension.
+ * * 7: An INT32 value, specifying the filter width.
+ * * 8: An INT32 value, specifying the filter height.
+ * * 9: An INT32 value, and has to be one of the {@link FuseCode} values.
+ * Specifies the activation to invoke on the result of each addition.
+ *
+ * Outputs:
+ * * 0: The output 4-D tensor, of shape [batches, out_height, out_width,
+ * depth].
+ */
+ ANEURALNETWORKS_AVERAGE_POOL_2D = 1,
+ /** Concatenates the input tensors along the given dimension.
+ *
+ * The input tensors must have identical type and the same dimensions except
+ * the dimension along the concatenation axis.
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: up to 4
+ *
+ * Inputs:
+ * 0 ~ n: The list on n input tensors, of shape [D0, D1, ..., Daxis(i), ...,
+ * Dm] n+1: An INT32 value, specifying the concatenation axis. n+2: An INT32
+ * value, and has to be one of the {@link FuseCode} values. Specifies the
+ * activation to invoke on the result of each addition.
+ *
+ * Outputs:
+ * * 0: The output, a tensor of the same type as the input tensors.
+ * The output shape is [D0, D1, ..., sum(Daxis(i)), ..., Dm].
+ */
+ ANEURALNETWORKS_CONCATENATION = 2,
+ /** Performs an 2-D convolution operation.
+ *
+ * The CONV_2D op sweeps a 2-D filter that can mix channels together over a
+ * batch of images, applying the filter to each window of each image of the
+ * appropriate size.
+ *
+ * The output dimensions are functions of the filter dimensions, stride, and
+ * padding.
+ *
+ * The values in the output tensor are computed as:
+ *
+ * output[batch, row, col, channel] =
+ * sum_{i, j} (
+ * input[batch, row + i, col + j, k] *
+ * filter[channel, row + i, col + j, k] +
+ * bias[channel]
+ * )
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: 4, with "NHWC" data layout.
+ *
+ * Inputs:
+ * * 0: A 4-D tensor, of shape [batches, height, width, depth_in], specifying
+ * the input.
+ * * 1: A 4-D tensor, of shape [depth_out, filter_height, filter_width,
+ * depth_in], specifying the filter.
+ * * 2: A 1-D tensor, of shape [depth_out], specifying the bias.
+ * For input tensor of {@link ANEURALNETWORKS_TENSOR_FLOAT32} type, the
+ * bias should also be of {@link ANEURALNETWORKS_TENSOR_FLOAT32}. For input
+ * tensor of {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} type, the bias should
+ * be of {@link ANEURALNETWORKS_TENSOR_INT32}.
+ * * 3: An INT32 value, specifying the padding on the left, in the ‘width’
+ * dimension.
+ * * 4: An INT32 value, specifying the padding on the right,in the ‘width’
+ * dimension.
+ * * 5: An INT32 value, specifying the padding on the top, in the ‘height’
+ * dimension.
+ * * 6: An INT32 value, specifying the padding on the bottom, in the ‘height’
+ * dimension.
+ * * 7: An INT32 value, specifying the output stride in the ‘width’ dimension.
+ * * 8: An INT32 value, specifying the output stride in the ‘height’
+ * dimension.
+ * * 9: An INT32 value, and has to be one of the {@link FuseCode} values.
+ * Specifies the activation to invoke on the result of each addition.
+ *
+ * Outputs:
+ * * 0: The output 4-D tensor, of shape [batches, out_height, out_width,
+ * depth_out].
+ */
+ ANEURALNETWORKS_CONV_2D = 3,
+ /** Performs a depthwise 2-D convolution operation.
+ *
+ * Given an input tensor of shape [batches, height, width, depth_in] and a
+ * filter tensor of shape [depth_out, filter_height, filter_width, depth_in]
+ * containing in_channels convolutional filters of depth 1, DEPTHWISE_CONV
+ * applies a different filter to each input channel (expanding from 1 channel
+ * to channel_multiplier channels for each), then concatenates the results
+ * together.
+ *
+ * The output has depth_out = depth_in * depth_multiplier channels.
+ * The output dimensions are functions of the filter dimensions, stride, and
+ * padding.
+ *
+ * The values in the output tensor are computed as:
+ *
+ * output[b, i, j, k * channel_multiplier + q] =
+ * sum_{di, dj} (
+ * input[b, strides[1] * i + di, strides[2] * j + dj, k] *
+ * filter[di, dj, k, q]
+ * )
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: 4, with "NHWC" data layout.
+ *
+ * Inputs:
+ * * 0: A 4-D tensor, of shape [batches, height, width, depth_in], specifying
+ * the input.
+ * * 1: A 4-D tensor, of shape [depth_out, filter_height, filter_width,
+ * depth_in], specifying the filter.
+ * * 2: A 1-D tensor, of shape [depth_out], specifying the bias.
+ * For input tensor of {@link ANEURALNETWORKS_TENSOR_FLOAT32} type, the
+ * bias should also be of {@link ANEURALNETWORKS_TENSOR_FLOAT32}. For input
+ * tensor of {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} type, the bias should
+ * be of {@link ANEURALNETWORKS_TENSOR_INT32}.
+ * * 3: An INT32 value, specifying the padding on the left, in the ‘width’
+ * dimension.
+ * * 4: An INT32 value, specifying the padding on the right,in the ‘width’
+ * dimension.
+ * * 5: An INT32 value, specifying the padding on the top, in the ‘height’
+ * dimension.
+ * * 6: An INT32 value, specifying the padding on the bottom, in the ‘height’
+ * dimension.
+ * * 7: An INT32 value, specifying the output stride in the ‘width’ dimension.
+ * * 8: An INT32 value, specifying the output stride in the ‘height’
+ * dimension.
+ * * 9: An INT32 value, specifying the depthwise multiplier.
+ * * 10: An INT32 value, and has to be one of the {@link FuseCode} values.
+ * Specifies the activation to invoke on the result of each addition.
+ *
+ * Outputs:
+ * * 0: The output 4-D tensor, of shape [batches, out_height, out_width,
+ * depth_out].
+ */
+ ANEURALNETWORKS_DEPTHWISE_CONV_2D = 4,
+ /** Rearranges data from depth into blocks of spatial data.
+ *
+ * More specifically, this op outputs a copy of the input tensor where values
+ * from the depth dimension are moved in spatial blocks to the height and
+ * width dimensions. The value block_size indicates the input block size and
+ * how the data is moved.
+ *
+ * Chunks of data of size block_size * block_size from depth are rearranged
+ * into non-overlapping blocks of size block_size x block_size.
+ *
+ * The width of the output tensor is input_depth * block_size, whereas the
+ * height is input_height * block_size. The depth of the input tensor must be
+ * divisible by block_size * block_size
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: 4, with "NHWC" data layout.
+ *
+ * Inputs:
+ * * 0: A 4-D tensor, of shape [batches, height, width, depth_in], specifying
+ * the input.
+ * * 1: An INT32 value, specifying the block_size. block_size must be >=1 and
+ * block_size * block_size must be a divisor of the input depth.
+ *
+ * Outputs:
+ * * 0: The output 4-D tensor, of shape [batch, height*block_size,
+ * width*block_size, depth/(block_size*block_size)].
+ */
+ ANEURALNETWORKS_DEPTH_TO_SPACE = 5,
+ /** Dequantizes the input tensor.
+ *
+ * The formula is:
+ *
+ * output = (input - zero_value) * scale.
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: up to 4
+ *
+ * Inputs:
+ * * 0: A tensor of type {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}.
+ *
+ * Outputs:
+ * * 0: The output tensor of same shape as input0, but with type
+ * {@link ANEURALNETWORKS_TENSOR_FLOAT32}.
+ */
+ ANEURALNETWORKS_DEQUANTIZE = 6,
+
+ /**
+ * Looks up items from a given tensor.
+ *
+ * Each item in the output is a raw copy of the corresponding item in
+ * the input “values”. If the the given “lookup” indices are out of bounds,
+ * the op will fail and an error will be reported.
+ *
+ * Inputs:
+ * * 0: Values. An n-D tensor of any type X (where n >= 2). E.g., if n is 2,
+ * then the shape would be [lookup_dimension, values_dimension], where
+ * “lookup_dimension” corresponds to the indexing dimension in the lookup
+ * table, and “values_dimension” to the contents.
+ * * 1: Lookups. An 1-D tensor of type T, of shape [lookup_size], where
+ * “lookup_size” is the number of elements to look for, and each entry
+ * corresponds to the first dimension of the “values” tensor.
+ *
+ * Output:
+ * * 0: A n-D tensor of type X and the same rank and shape as the “values”
+ * tensor, except for the first dimension which has size “lookup_size”.
+ */
+ ANEURALNETWORKS_EMBEDDING_LOOKUP = 7,
+
+ /** Computes element-wise floor() on the input tensor.
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ *
+ * Supported tensor rank: up to 4
+ *
+ * Inputs:
+ * * 0: A tensor.
+ *
+ * Outputs:
+ * * 0: The output, a tensor of the same type and dimensions as input0.
+ */
+ ANEURALNETWORKS_FLOOR = 8,
+ /** Denotes a fully (densely) connected layer, which connects all elements in
+ * the input tensor with each element in the output tensor.
+ *
+ * This layer implements the operation:
+ *
+ * outputs = activation(inputs * weights’ + bias)
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: up to 4.
+ *
+ * Inputs:
+ * * 0: A tensor, specifying the input. If rank is greater than 2, then it
+ * gets flattened to a 2-D Tensor. The 2-D Tensor is handled as if dimensions
+ * corresponded to shape [batch_size, input_size], where “batch_size”
+ * corresponds to the batching dimension, and “input_size” is the size of the
+ * input.
+ * * 1: A 2-D tensor, specifying the weights, of shape [num_units,
+ * input_size], where "num_units" corresponds to the number of output nodes.
+ * * 2: A 1-D tensor, of shape [num_units], specifying the bias.
+ * For input tensor of {@link ANEURALNETWORKS_TENSOR_FLOAT32} type, the
+ * bias should also be of {@link ANEURALNETWORKS_TENSOR_FLOAT32}. For input
+ * tensor of {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} type, the bias should
+ * be of {@link ANEURALNETWORKS_TENSOR_INT32}.
+ * * 3: An INT32 value, and has to be one of the {@link FuseCode} values.
+ * Specifies the activation to invoke on the result of each addition.
+ *
+ * Outputs:
+ * * 0: The output tensor, of shape [batch_size, num_units].
+ */
+ ANEURALNETWORKS_FULLY_CONNECTED = 9,
+
+ /**
+ * Looks up values of a hash table with given keys.
+ *
+ * Inputs:
+ * * 0: Lookups. A 1-D int32 tensor with shape [ k ].
+ * * 1: Keys. A 1-D int32 tensor with shape [ n ], *MUST* be sorted in
+ * ascending order.
+ * * 2: Values. A tensor with shape [ n … ].
+ *
+ * Outputs:
+ * * 0: Output. A tensor with shape [ k …].
+ * * 1: Hits. A uint8 tensor with shape [ k ] indicates whether the lookup
+ * hits or not.
+ */
+ ANEURALNETWORKS_HASHTABLE_LOOKUP = 10,
+
+ /** Applies L2 normalization along the depth dimension.
+ *
+ * The values in the output tensor are computed as:
+ *
+ * output[batch, row, col, channel] =
+ * input[batch, row, col, channel] /
+ * sqrt(sum_{c} pow(input[batch, row, col, c], 2))
+ *
+ * For x with more dimensions, independently normalizes each 1-D slice along
+ * dimension dim.
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ *
+ * Supported tensor rank: 4, with "NHWC" data layout.
+ *
+ * Inputs:
+ * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the
+ * input.
+ *
+ * Outputs:
+ * * 0: The output 4-D tensor, of shape [batches, out_height, out_width,
+ * depth].
+ */
+ ANEURALNETWORKS_L2_NORMALIZATION = 11,
+
+ /** Performs an 2-D L2 pooling operation.
+ *
+ * The output dimensions are functions of the filter dimensions, stride, and
+ * padding.
+ *
+ * The values in the output tensor are computed as:
+ *
+ * output[batch, row, col, channel] =
+ * sqrt(sum_{i, j} pow(input[batch, row + i, col + j, channel], 2) /
+ * sum(1))
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ *
+ * Supported tensor rank: 4, with "NHWC" data layout.
+ *
+ * Inputs:
+ * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the
+ * input.
+ * * 1: An INT32 value, specifying the padding on the left, in the ‘width’
+ * dimension.
+ * * 2: An INT32 value, specifying the padding on the right,in the ‘width’
+ * dimension.
+ * * 3: An INT32 value, specifying the padding on the top, in the ‘height’
+ * dimension.
+ * * 4: An INT32 value, specifying the padding on the bottom, in the ‘height’
+ * dimension.
+ * * 5: An INT32 value, specifying the output stride in the ‘width’ dimension.
+ * * 6: An INT32 value, specifying the output stride in the ‘height’
+ * dimension.
+ * * 7: An INT32 value, specifying the filter width.
+ * * 8: An INT32 value, specifying the filter height.
+ * * 9: An INT32 value, and has to be one of the {@link FuseCode} values.
+ * Specifies the activation to invoke on the result of each addition.
+ *
+ * Outputs:
+ * * 0: The output 4-D tensor, of shape [batches, out_height, out_width,
+ * depth].
+ */
+ ANEURALNETWORKS_L2_POOL_2D = 12,
+ /** Applies Local Response Normalization along the depth dimension.
+ *
+ * The 4-D input tensor is treated as a 3-D array of 1-D vectors (along the
+ * last dimension), and each vector is normalized independently. Within a
+ * given vector, each component is divided by the weighted, squared sum of
+ * inputs within depth_radius.
+ *
+ * The output is calculated using this formula:
+ *
+ * sqr_sum[a, b, c, d] =
+ * sum(pow(input[a, b, c, d - depth_radius : d + depth_radius + 1], 2)
+ * output = input / pow((bias + alpha * sqr_sum), beta)
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ *
+ * Supported tensor rank: 4, with "NHWC" data layout.
+ *
+ * Inputs:
+ * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the
+ * input.
+ * * 1: An INT32 value, specifying the radius of the normalization window.
+ * * 2: A FLOAT32 value, specifying the bias, must not be zero.
+ * * 3: A FLOAT32 value, specifying the scale factor, alpha.
+ * * 4: A FLOAT32 value, specifying the exponent, beta.
+ *
+ * Outputs:
+ * * 0: The output tensor of same shape as input0.
+ */
+ ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION = 13,
+ /** Computes sigmoid activation on the input tensor element-wise.
+ *
+ * The output is calculated using this formula:
+ *
+ * output = 1 / (1 + exp(-input))
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: up to 4.
+ *
+ * Inputs:
+ * * 0: A tensor, specifying the input.
+ *
+ * Outputs:
+ * * 0: The output tensor of same shape as input0.
+ */
+ ANEURALNETWORKS_LOGISTIC = 14,
+
+ /**
+ * Projects an input to a bit vector via locality senstive hashing.
+ *
+ * Inputs:
+ * * 0: Hash functions. Dim.size == 2, DataType: Float.
+ * Tensor[0].Dim[0]: Number of hash functions.
+ * Tensor[0].Dim[1]: Number of seeds per hash functions.
+ * Tensor[0].Dim[1] <= 32 in sparse case.
+ *
+ * * 1: Input. Dim.size >= 1, no restriction on DataType.
+ * * 2: Weight. Optional. Dim.size == 1, DataType: Float.
+ * If not set, each input element is considered to have the same weight of
+ * 1.0.
+ * Tensor[1].Dim[0] == Tensor[2].Dim[0]
+ * * 3: Type:
+ * Sparse: Value LSHProjectionType_SPARSE(=1).
+ * Computed bit vector is considered to be sparse.
+ * Each output element is an int32 made up of multiple bits computed
+ * from hash functions.
+ *
+ * Dense: Value LSHProjectionType_DENSE(=2).
+ * Computed bit vector is considered to be dense. Each output element
+ * represents a bit and can take the value of either 0 or 1.
+ *
+ * Outputs:
+ * * 0: If the projection type is sparse:
+ * Output.Dim == { Tensor[0].Dim[0] }
+ * A tensor of int32 that represents hash signatures.
+ * If the projection type is Dense:
+ * Output.Dim == { Tensor[0].Dim[0] * Tensor[0].Dim[1] }
+ * A flattened tensor that represents projected bit vectors.
+ */
+ ANEURALNETWORKS_LSH_PROJECTION = 15,
+
+ /**
+ * Long short-term memory unit (LSTM) recurrent network layer.
+ *
+ * The default non-peephole implementation is based on:
+ * http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
+ * S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural
+ * Computation, 9(8):1735-1780, 1997.
+ *
+ * The peephole implementation is based on:
+ * https://research.google.com/pubs/archive/43905.pdf
+ * Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory
+ * recurrent neural network architectures for large scale acoustic modeling."
+ * INTERSPEECH, 2014.
+ *
+ * The coupling of input and forget gate (CIFG) is based on:
+ * http://arxiv.org/pdf/1503.04069.pdf
+ * Greff et al. "LSTM: A Search Space Odyssey"
+ *
+ * The class has the following independently optional inputs:
+ * * If input gate (if CIFG): “input_to_forget_weights”,
+ * “recurrent_to_input_weights”, “cell_to_input_weights”, “input_gate_bias”.
+ * * If no peephole connections: “cell_to_input_weights”,
+ * “cell_to_forget_weights”, “cell_to_output_weights”.
+ * * If no projection layer: “projection_weights” and “projection_bias”.
+ * * If no projection bias: “projection_bias”.
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ *
+ * Inputs:
+ * * 0: Input.
+ * A 2-D tensor of type T, of shape [batch_size, input_size], where
+ * “batch_size” corresponds to the batching dimension, and “input_size”
+ * is the size of the input.
+ * * 1: input_to_input_weights.
+ * A 2-D tensor of type T, of shape [num_units, input_size], where
+ * “num_units” corresponds to the number of cell units.
+ * * 2: input_to_forget_weights.
+ * A 2-D tensor of type T, of shape [num_units, input_size].
+ * * 3: input_to_cell_weights.
+ * A 2-D tensor of type T, of shape [num_units, input_size].
+ * * 4: input_to_output_weights.
+ * A 2-D tensor of type T, of shape [num_units, input_size].
+ * * 5: recurrent_to_input_weights.
+ * A 2-D tensor of type T, of shape [num_units, output_size], where
+ * “output_size” corresponds to either the number of cell units (i.e.,
+ * “num_units”), or the second dimension of the “projection_weights”, if
+ * defined.
+ * * 6: recurrent_to_forget_weights.
+ * A 2-D tensor of type T, of shape [num_units, output_size].
+ * * 7: recurrent_to_cell_weights.
+ * A 2-D tensor of type T, of shape [num_units, output_size].
+ * * 8: recurrent_to_output_weights.
+ * A 2-D tensor of type T, of shape [num_units, output_size].
+ * * 9: cell_to_input_weights.
+ * A 1-D tensor of type T, of shape [num_units].
+ * * 10:cell_to_forget_weights.
+ * A 1-D tensor of type T, of shape [num_units].
+ * * 11:cell_to_output_weights.
+ * A 1-D tensor of type T, of shape [num_units].
+ * * 12:input_gate_bias.
+ * A 1-D tensor of type T, of shape [num_units].
+ * * 13:forget_gate_bias.
+ * A 1-D tensor of type T, of shape [num_units].
+ * * 14:cell_bias.
+ * A 1-D tensor of type T, of shape [num_units].
+ * * 15:output_gate_bias.
+ * A 1-D tensor of type T, of shape [num_units].
+ * * 16:projection_weights.
+ * A 2-D tensor of type T, of shape [output_size, num_units].
+ * * 17:projection_bias.
+ * A 1-D tensor of type T, of shape [output_size].
+ *
+ * Parameters:
+ * * 18:fused_activation_function.
+ * An (optional) ActivationFunctionType indicating the activation
+ * function.
+ * If “NONE” is specified then it results in a linear activation.
+ * * 19:cell_clip.
+ * A clipping threshold for the cell state, such that values are bound
+ * within [-cell_clip, cell_clip]. If set to 0.0 then clipping is
+ * disabled.
+ * * 20:proj_clip.
+ * A clipping threshold for the output from the projection layer, such
+ * that values are bound within [-proj_clip, proj_clip]. If set to 0.0
+ * then clipping is disabled.
+ *
+ * Outputs:
+ * * 0: scratch_buffer.
+ * A 3-D tensor of type T, of shape [batch_size, num_cell, 4].
+ * * 1: output_state.
+ * A 2-D tensor of type T, of shape [batch_size, output_size].
+ * * 2: cell_state.
+ * A 2-D tensor of type T, of shape [batch_size, num_units].
+ * * 3: output.
+ * A 2-D tensor of type T, of shape [batch_size, output_size]. This is
+ * effectively the same as the current “output_state” value.
+ */
+ ANEURALNETWORKS_LSTM = 16,
+
+ /** Performs an 2-D max pooling operation.
+ *
+ * The output dimensions are functions of the filter dimensions, stride, and
+ * padding.
+ *
+ * The values in the output tensor are computed as:
+ *
+ * output[batch, row, col, channel] =
+ * max_{i, j} (input[batch, row + i, col + j, channel])
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: 4, with "NHWC" data layout.
+ *
+ * Inputs:
+ * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the
+ * input.
+ * * 1: An INT32 value, specifying the padding on the left, in the ‘width’
+ * dimension.
+ * * 2: An INT32 value, specifying the padding on the right,in the ‘width’
+ * dimension.
+ * * 3: An INT32 value, specifying the padding on the top, in the ‘height’
+ * dimension.
+ * * 4: An INT32 value, specifying the padding on the bottom, in the ‘height’
+ * dimension.
+ * * 5: An INT32 value, specifying the output stride in the ‘width’ dimension.
+ * * 6: An INT32 value, specifying the output stride in the ‘height’
+ * dimension.
+ * * 7: An INT32 value, specifying the filter width.
+ * * 8: An INT32 value, specifying the filter height.
+ * * 9: An INT32 value, and has to be one of the {@link FuseCode} values.
+ * Specifies the activation to invoke on the result of each addition.
+ *
+ * Outputs:
+ * * 0: The output 4-D tensor, of shape [batches, out_height, out_width,
+ * depth].
+ */
+ ANEURALNETWORKS_MAX_POOL_2D = 17,
+
+ /** Multiplies two tensors, elment-wise.
+ *
+ * Takes two input tensors of identical type and compatible dimensions. The
+ * output is the product of both input tensors, optionally modified by an
+ * activation function.
+ *
+ * Two dimensions are compatible when:
+ * 1. they are equal, or
+ * 2. one of them is 1
+ *
+ * The size of the resulting output is the maximum size along each dimension
+ * of the input operands. It starts with the trailing dimensions, and works
+ * its way forward.
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ *
+ * Supported tensor rank: up to 4
+ *
+ * Inputs:
+ * * 0: A tensor.
+ * * 1: A tensor of the same type, and compatible dimensions as input0.
+ * * 2: An INT32 value, and has to be one of the {@link FuseCode} values.
+ * Specifies the activation to invoke on the result of each addition.
+ *
+ * Outputs:
+ * * 0: The product, a tensor of the same type as input0.
+ */
+ ANEURALNETWORKS_MUL = 18,
+ /** Computes rectified linear activation on the input tensor element-wise.
+ *
+ * The output is calculated using this formula:
+ *
+ * output = max(0, input)
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: up to 4.
+ *
+ * Inputs:
+ * * 0: A tensor, specifying the input.
+ *
+ * Outputs:
+ * * 0: The output tensor of same shape as input0.
+ */
+ ANEURALNETWORKS_RELU = 19,
+ /** Computes rectified linear 1 activation on the input tensor element-wise.
+ *
+ * The output is calculated using this formula:
+ *
+ * output = min(1.f, max(-1.f, input))
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: up to 4.
+ *
+ * Inputs:
+ * * 0: A tensor, specifying the input.
+ *
+ * Outputs:
+ * * 0: The output tensor of same shape as input0.
+ */
+ ANEURALNETWORKS_RELU1 = 20,
+ /** Computes rectified linear 6 activation on the input tensor element-wise.
+ *
+ * The output is calculated using this formula:
+ *
+ * output = min(6, max(0, input))
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: up to 4.
+ *
+ * Inputs:
+ * * 0: A tensor, specifying the input.
+ *
+ * Outputs:
+ * * 0: The output tensor of same shape as input0.
+ */
+ ANEURALNETWORKS_RELU6 = 21,
+ /** Reshapes a tensor.
+ *
+ * Given tensor, this operation returns a tensor that has the same values as
+ * tensor, but with a newly specified shape.
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: up to 4.
+ *
+ * Inputs:
+ * * 0: A tensor, specifying the tensor to be reshaped.
+ * * 1: A 1-D tensor of type {@link ANEURALNETWORKS_TENSOR_INT32}, defining
+ * the shape of the output tensor. The number of elements implied by shape
+ * must be the same as the number of elements in the input tensor.
+ *
+ * Outputs:
+ * * 0: The output tensor, of shape specified by the input shape.
+ */
+ ANEURALNETWORKS_RESHAPE = 22,
+ /** Resizes images to given size using the bilinear interpretation.
+ *
+ * Resized images will be distorted if their original aspect ratio is not the
+ * same as input.
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ *
+ * Supported tensor rank: 4, with "NHWC" data layout.
+ *
+ * Inputs:
+ * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the
+ * input.
+ * * 1: An INT32 value, specifying the output width of the output tensor.
+ * * 2: An INT32 value, specifying the output height of the output tensor.
+ *
+ * Outputs:
+ * * 0: The output 4-D tensor, of shape [batches, new_height, new_width,
+ * depth].
+ */
+ ANEURALNETWORKS_RESIZE_BILINEAR = 23,
+
+ /**
+ * A basic recurrent neural network layer.
+ *
+ * This layer implements the operation:
+ * outputs = state = activation(inputs * input_weights + state *
+ * recurrent_weights + bias)
+ *
+ * Where:
+ * * “input_weights” is a weight matrix that multiplies the inputs;
+ * * “recurrent_weights” is a weight matrix that multiplies the current
+ * “state” which itself is the output from the previous time step
+ * computation;
+ * * “bias” is a bias vector (added to each output vector in the batch);
+ * * “activation” is the function passed as the “fused_activation_function”
+ * argument (if not “NONE”).
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ *
+ * Inputs:
+ * * 0: input.
+ * A 2-D tensor of type T, of shape [batch_size, input_size], where
+ * “batch_size” corresponds to the batching dimension, and “input_size”
+ * is the size of the input.
+ * * 1: weights.
+ * A 2-D tensor of type T, of shape [num_units, input_size], where
+ * “num_units” corresponds to the number of units.
+ * * 2: recurrent_weights.
+ * A 2-D tensor of type T, of shape [num_units, num_units], with columns
+ * corresponding to the weights from each unit.
+ * * 3: bias.
+ * A 1-D tensor of type T, of shape [num_units].
+ *
+ * For FLOAT32 input tensor, bias must also be FLOAT32.
+ * For UINT8 input tensor, bias must be INT32.
+ *
+ * Parameters
+ * * 4: fused_activation_function.
+ * An (optional) ActivationFunctionType indicating the activation
+ * function. If “NONE” is specified then it results in a linear
+ * activation.
+ *
+ * * 5: Hidden state.
+ * A 2-D tensor of type T, of shape [batch_size, num_units].
+ *
+ * Outputs:
+ * * 0: output.
+ * A 2-D tensor of type T, of shape [batch_size, num_units]. This is
+ * effectively the same as the current state value.
+ */
+ ANEURALNETWORKS_RNN = 24,
+
+ /** Computes the softmax activation on the input tensor element-wise, per
+ * batch, by normalizing the input vector so the maximum coefficient is zero.
+ *
+ * The output is calculated using this formula:
+ *
+ * output[batch, i] =
+ * exp((input[batch, i] - max(input[batch, :])) * beta) /
+ * sum_{k}{exp((input[batch, k] - max(input[batch, :])) * beta)}
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: 2 or 4.
+ *
+ * Inputs:
+ * * 0: A 2-D or 4-D tensor, specifying the tensor to be reshaped.
+ * * 1: A FLOAT32 value, specifying the scaling factor for the exponent, beta.
+ *
+ * Outputs:
+ * * 0: The output tensor of same shape as input0.
+ */
+ ANEURALNETWORKS_SOFTMAX = 25,
+
+ /** Rearranges blocks of spatial data, into depth.
+ *
+ * More specifically, this op outputs a copy of the input tensor where values
+ * from the height and width dimensions are moved to the depth dimension. The
+ * value block_size indicates the input block size and how the data is moved.
+ *
+ * Chunks of data of size block_size * block_size from depth are rearranged
+ * into non-overlapping blocks of size block_size x block_size.
+ *
+ * The depth of the output tensor is input_depth * block_size * block_size.
+ * The input tensor's height and width must be divisible by block_size.
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: 4, with "NHWC" data layout.
+ *
+ * Inputs:
+ * * 0: A 4-D tensor, of shape [batches, height, width, depth_in], specifying
+ * the input.
+ * * 1: An INT32 value, specifying the block_size. block_size must be >=1 and
+ * block_size must be a divisor of both the input height and width.
+ *
+ * Outputs:
+ * * 0: The output 4-D tensor, of shape [batch, height/block_size,
+ * width/block_size, depth*block_size*block_size].
+ */
+ ANEURALNETWORKS_SPACE_TO_DEPTH = 26,
+
+ /**
+ * SVDF op is a kind of stateful layer derived from the notion that a
+ * densely connected layer that's processing a sequence of input frames can
+ * be approximated by using a singular value decomposition of each of its
+ * nodes. The implementation is based on:
+ *
+ * https://research.google.com/pubs/archive/43813.pdf
+ *
+ * P. Nakkiran, R. Alvarez, R. Prabhavalkar, C. Parada.
+ * “Compressing Deep Neural Networks using a Rank-Constrained Topology”.
+ * INTERSPEECH, 2015.
+ *
+ * It processes the incoming input using a 2-stage filtering mechanism:
+ * * stage 1 performs filtering on the "features" dimension, whose outputs get
+ * pushed into a memory of fixed-size memory_size.
+ * * stage 2 performs filtering on the "time" dimension of the memory_size
+ * memoized outputs of stage 1.
+ *
+ * Specifically, for rank 1, this layer implements the operation:
+ *
+ * memory = push(conv1d(inputs, weights_feature, feature_dim, "VALID"));
+ * outputs = activation(memory * weights_time + bias);
+ *
+ * Where:
+ * * “weights_feature” is a weights matrix that processes the inputs (by
+ * convolving the input with every “feature filter”), and whose outputs get
+ * pushed, stacked in order, into the fixed-size “memory” (the oldest entry
+ * gets dropped);
+ * * “weights_time” is a weights matrix that processes the “memory” (by a
+ * batched matrix multiplication on the num_units);
+ * * “bias” is an optional bias vector (added to each output vector in the
+ * batch); and
+ * * “activation” is the function passed as the “fused_activation_function”
+ * argument (if not “NONE”).
+ *
+ * Each rank adds a dimension to the weights matrices by means of stacking
+ * the filters.
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ *
+ * Inputs:
+ * * 0: input.
+ * A 2-D tensor of type T, of shape [batch_size, input_size], where
+ * “batch_size” corresponds to the batching dimension, and “input_size”
+ * is the size of the input.
+ * * 1: weights_feature.
+ * A 2-D tensor of type T, of shape [num_units, input_size], where
+ * “num_units” corresponds to the number of units.
+ * * 2: weights_time.
+ * A 2-D tensor of type T, of shape [num_units, memory_size], where
+ * “memory_size” corresponds to the fixed-size of the memory.
+ * * 3: bias.
+ * A optional 1-D tensor of type T, of shape [num_units].
+ *
+ * For FLOAT32 input tensor, bias must also be FLOAT32.
+ * For UINT8 input tensor, bias must be INT32.
+ *
+ * Parameters:
+ * * 4: rank.
+ * The rank of the SVD approximation.
+ * * 5: fused_activation_function.
+ * An (optional) ActivationFunctionType indicating the activation
+ * function. If “NONE” is specified then it results in a linear activation.
+ *
+ * Outputs:
+ * * 0: state.
+ * A 2-D tensor of type T, of shape [batch_size, (memory_size - 1) *
+ * num_units * rank].
+ * * 1: output.
+ * A 2-D tensor of type T, of shape [batch_size, num_units].
+ */
+ ANEURALNETWORKS_SVDF = 27,
+
+ /** Computes hyperbolic tangent of input tensor element-wise.
+ *
+ * The output is calculated using this formula:
+ *
+ * output = tanh(input)
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ *
+ * Supported tensor rank: up to 4.
+ *
+ * Inputs:
+ * * 0: A tensor, specifying the input.
+ *
+ * Outputs:
+ * * 0: The output tensor of same shape as input0.
+ */
+ ANEURALNETWORKS_TANH = 28,
+};
+
+/**
+ * Fused activation function types.
+ *
+ */
+enum {
+ /** NO fused activation function. */
+ ANEURALNETWORKS_FUSED_NONE = 0,
+ /** Fused ReLU activation function. */
+ ANEURALNETWORKS_FUSED_RELU = 1,
+ /** Fused ReLU1 activation function. */
+ ANEURALNETWORKS_FUSED_RELU1 = 2,
+ /** Fused ReLU6 activation function. */
+ ANEURALNETWORKS_FUSED_RELU6 = 3,
+};
+
+/**
+ * Execution preferences.
+ */
+enum {
+ /**
+ * Prefer executing in a way that minimizes battery drain.
+ * This is desirable for compilations that will be executed often.
+ */
+ ANEURALNETWORKS_PREFER_LOW_POWER = 0,
+ /**
+ * Prefer returning a single answer as fast as possible, even if this causes
+ * more power consumption.
+ */
+ ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1,
+ /**
+ * Prefer maximizing the throughput of successive frames, for example when
+ * processing successive frames coming from the camera.
+ */
+ ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2,
+};
+
+/**
+ * Result codes.
+ */
+enum {
+ ANEURALNETWORKS_NO_ERROR = 0,
+ ANEURALNETWORKS_OUT_OF_MEMORY = 1,
+ ANEURALNETWORKS_INCOMPLETE = 2,
+ ANEURALNETWORKS_UNEXPECTED_NULL = 3,
+ ANEURALNETWORKS_BAD_DATA = 4,
+ ANEURALNETWORKS_OP_FAILED = 5,
+ ANEURALNETWORKS_UNMAPPABLE = 5,
+ ANEURALNETWORKS_BAD_STATE = 6,
+};
+
+/**
+ * ANeuralNetworksMemory is an opaque type that represents memory.
+ *
+ * This type is used to represent shared memory, memory mapped files,
+ * and similar memories.
+ *
+ * By using shared memory, a program can efficiently communicate to the
+ * runtime and drivers the tensors that define a model. See
+ * {@link ANeuralNetworksModel_setOperandValueFromMemory}. An application
+ * should typically create one shared memory object that contains every tensor
+ * needed to define a model. {@link ANeuralNetworksMemory_createFromFd} can be
+ * used to create shared memory from a file handle. {@link
+ * ANeuralNetworksMemory_createShared} can be used to directly created shared
+ * memory.
+ *
+ * Memory objects can also be used to specify the input and output arguments of
+ * an execution. See {@link ANeuralNetworksExecution_setInputFromMemory}
+ * and {@link ANeuralNetworksExecution_setOutputFromMemory}.
+ */
+typedef struct ANeuralNetworksMemory ANeuralNetworksMemory;
+
+/**
+ * ANeuralNetworksModel is an opaque type that contains a description of the
+ * mathematical operations that constitute the model.
+ *
+ * <p>The model will be built by calling<ul>
+ * <li>{@link ANeuralNetworksModel_create},</li>
+ * <li>{@link ANeuralNetworksModel_addOperation},</li>
+ * <li>{@link ANeuralNetworksModel_addOperand},</li>
+ * </ul>
+ *
+ * A model is completed by calling {@link ANeuralNetworksModel_finish}.
+ * A model is destroyed by calling {@link ANeuralNetworksModel_free}.
+ *
+ * <p>It is the application's responsibility to make sure that only one thread
+ * modifies a model at a given time. It is however safe for more than one
+ * thread to use the model once {@link ANeuralNetworksModel_finish} has
+ * returned.</p>
+ *
+ * <p>It is also the application's responsibility to ensure that there are no
+ * other uses of the model after calling {@link ANeuralNetworksModel_free}. This
+ * includes any compilation or execution object created using the model.</p>
+ */
+typedef struct ANeuralNetworksModel ANeuralNetworksModel;
+
+/**
+ * ANeuralNetworksCompilation is an opaque type that can be used to compile
+ * a machine learning model.
+ *
+ * <p>To use:<ul>
+ * <li>Create a new compilation instance by calling the
+ * {@link ANeuralNetworksCompilation_create} function.</li>
+ * <li>Perform the compilation with {@link
+ * ANeuralNetworksCompilation_start}.</li> <li>Wait for the compilation to
+ * complete with {@link ANeuralNetworksCompilation_wait}.</li> <li>Use the
+ * compilation as many times as needed with {@link
+ * ANeuralNetworksExecution_create}.</li> <li>Destroy the compilation with
+ * {@link ANeuralNetworksCompilation_free} once all executions using the
+ * compilation have completed.</li></ul></p>
+ *
+ * <p>A compilation cannot be modified once {@link
+ * ANeuralNetworksCompilation_start} has been called on it.</p>
+ *
+ * <p>It is the application's responsibility to make sure that only one thread
+ * modifies a compilation at a given time. It is however safe for more than one
+ * thread to use {@link ANeuralNetworksCompilation_wait} at the same time.
+ * It is also safe for multiple threads to use a compilation object once
+ * {@link ANeuralNetworksCompilation_wait} has completed.</p>
+ *
+ * <p>It is also the application's responsibility to ensure that there are no
+ * other uses of the compilation after calling {@link
+ * ANeuralNetworksCompilation_free}. This includes any execution object created
+ * using the compilation.</p>
+ */
+typedef struct ANeuralNetworksCompilation ANeuralNetworksCompilation;
+
+/**
+ * ANeuralNetworksExecution is an opaque type that can be used to apply a
+ * machine learning model to a set of inputs.
+ *
+ * <p>To use:<ul>
+ * <li>Create a new execution instance by calling the
+ * {@link ANeuralNetworksExecution_create} function.</li>
+ * <li>Associate data to the model inputs with
+ * {@link ANeuralNetworksExecution_setInput} or
+ * {@link ANeuralNetworksExecution_setInputFromMemory}.</li>
+ * <li>Associate output buffers to the model outputs with
+ * {@link ANeuralNetworksExecution_setOutput} or
+ * {@link ANeuralNetworksExecution_setOutputFromMemory}.</li>
+ * <li>Apply the model with {@link
+ * ANeuralNetworksExecution_startCompute}.</li> <li>Wait for the execution to
+ * complete with {@link ANeuralNetworksExecution_wait}.</li> <li>Destroy the
+ * execution with
+ * {@link ANeuralNetworksExecution_free}.</li></ul></p>
+ *
+ * <p>An execution cannot be modified once {@link
+ * ANeuralNetworksExecution_start} has been called on it.</p>
+ *
+ * <p>An execution can be applied to a model with
+ * {@link ANeuralNetworksExecution_startCompute} only once. Create new
+ * executions to do new evaluations of the model.</p>
+ *
+ * <p>It is the application's responsibility to make sure that only one thread
+ * modifies an execution at a given time. It is however safe for more than one
+ * thread to use {@link ANeuralNetworksExecution_wait} at the same time.</p>
+ *
+ * <p>It is also the application's responsibility to ensure that there are no
+ * other uses of the request after calling {@link
+ * ANeuralNetworksRequest_free}.</p>
+ */
+typedef struct ANeuralNetworksExecution ANeuralNetworksExecution;
+
+/**
+ * ANeuralNetworksOperandType describes the type of an operand.
+ * This structure is used to describe both scalars and tensors.
+ */
+typedef struct ANeuralNetworksOperandType {
+ /** The data type, e.g ANEURALNETWORKS_INT8. */
+ int32_t type;
+ /** The number of dimensions. It should be 0 for scalars. */
+ uint32_t dimensionCount;
+ /** The dimensions of the tensor. It should be nullptr for scalars. */
+ const uint32_t* dimensions;
+ /** These two fields are only used for quantized tensors.
+ * They should be zero for scalars and non-fixed point tensors.
+ * The dequantized value of each entry is (value - offset) * scale.
+ */
+ float scale;
+ int32_t zeroPoint;
+} ANeuralNetworksOperandType;
+
+/**
+ * ANeuralNetworksEvent is an opaque type that represents an event
+ * that will be signaled once an execution completes.
+ */
+typedef struct ANeuralNetworksEvent ANeuralNetworksEvent;
+
+typedef int32_t ANeuralNetworksOperationType;
+
+// nn api function types
+
+typedef int (*ANeuralNetworksMemory_createFromFd_fn)(
+ size_t size, int protect, int fd, size_t offset,
+ ANeuralNetworksMemory** memory);
+
+typedef void (*ANeuralNetworksMemory_free_fn)(ANeuralNetworksMemory* memory);
+
+typedef int (*ANeuralNetworksModel_create_fn)(ANeuralNetworksModel** model);
+
+typedef int (*ANeuralNetworksModel_finish_fn)(ANeuralNetworksModel* model);
+
+typedef void (*ANeuralNetworksModel_free_fn)(ANeuralNetworksModel* model);
+
+typedef int (*ANeuralNetworksCompilation_create_fn)(
+ ANeuralNetworksModel* model, ANeuralNetworksCompilation** compilation);
+
+typedef void (*ANeuralNetworksCompilation_free_fn)(
+ ANeuralNetworksCompilation* compilation);
+
+typedef int (*ANeuralNetworksCompilation_setPreference_fn)(
+ ANeuralNetworksCompilation* compilation, int32_t preference);
+
+typedef int (*ANeuralNetworksCompilation_finish_fn)(
+ ANeuralNetworksCompilation* compilation);
+
+typedef int (*ANeuralNetworksModel_addOperand_fn)(
+ ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type);
+
+typedef int (*ANeuralNetworksModel_setOperandValue_fn)(
+ ANeuralNetworksModel* model, int32_t index, const void* buffer,
+ size_t length);
+
+typedef int (*ANeuralNetworksModel_setOperandValueFromMemory_fn)(
+ ANeuralNetworksModel* model, int32_t index,
+ const ANeuralNetworksMemory* memory, size_t offset, size_t length);
+
+typedef int (*ANeuralNetworksModel_addOperation_fn)(
+ ANeuralNetworksModel* model, ANeuralNetworksOperationType type,
+ uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount,
+ const uint32_t* outputs);
+
+typedef int (*ANeuralNetworksModel_identifyInputsAndOutputs_fn)(
+ ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs,
+ uint32_t outputCount, const uint32_t* outputs);
+
+typedef int (*ANeuralNetworksExecution_create_fn)(
+ ANeuralNetworksCompilation* compilation,
+ ANeuralNetworksExecution** execution);
+
+typedef void (*ANeuralNetworksExecution_free_fn)(
+ ANeuralNetworksExecution* execution);
+
+typedef int (*ANeuralNetworksExecution_setInput_fn)(
+ ANeuralNetworksExecution* execution, int32_t index,
+ const ANeuralNetworksOperandType* type, const void* buffer, size_t length);
+
+typedef int (*ANeuralNetworksExecution_setInputFromMemory_fn)(
+ ANeuralNetworksExecution* execution, int32_t index,
+ const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory,
+ size_t offset, size_t length);
+
+typedef int (*ANeuralNetworksExecution_setOutput_fn)(
+ ANeuralNetworksExecution* execution, int32_t index,
+ const ANeuralNetworksOperandType* type, void* buffer, size_t length);
+
+typedef int (*ANeuralNetworksExecution_setOutputFromMemory_fn)(
+ ANeuralNetworksExecution* execution, int32_t index,
+ const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory,
+ size_t offset, size_t length);
+
+typedef int (*ANeuralNetworksExecution_startCompute_fn)(
+ ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event);
+
+typedef int (*ANeuralNetworksEvent_wait_fn)(ANeuralNetworksEvent* event);
+
+typedef void (*ANeuralNetworksEvent_free_fn)(ANeuralNetworksEvent* event);
+
+/**
+ * Creates a shared memory object from a file descriptor.
+ *
+ * The shared memory is backed by a file descriptor via mmap.
+ * See {@link ANeuralNetworksMemory} for a description on how to use
+ * this shared memory.
+ *
+ * @param size The requested size in bytes.
+ * Must not be larger than the file size.
+ * @param prot The desired memory protection for the mapping.
+ * It is either PROT_NONE or the bitwise OR of one or
+ * more of the following flags: PROT_READ, PROT_WRITE.
+ * @param fd The requested file descriptor.
+ * The file descriptor has to be mmap-able. The file
+ * descriptor will be duplicated.
+ * @param offset The offset to the beginning of the file of the area to map.
+ * The offset has to be aligned to a page size.
+ * @param memory The memory object to be created.
+ * Set to NULL if unsuccessful.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if the request completed normally.
+ */
+inline int ANeuralNetworksMemory_createFromFd(size_t size, int protect, int fd,
+ size_t offset,
+ ANeuralNetworksMemory** memory) {
+ LOAD_FUNCTION(ANeuralNetworksMemory_createFromFd);
+ EXECUTE_FUNCTION_RETURN(size, protect, fd, offset, memory);
+}
+
+/**
+ * Delete a memory object.
+ *
+ * Destroys the object used by the run time to keep track of the memory.
+ * This will free the underlying actual memory if no other code has open
+ * handles to this memory.
+ *
+ * @param memory The memory object to be freed.
+ */
+inline void ANeuralNetworksMemory_free(ANeuralNetworksMemory* memory) {
+ LOAD_FUNCTION(ANeuralNetworksMemory_free);
+ EXECUTE_FUNCTION(memory);
+}
+
+/**
+ * Create an empty {@link ANeuralNetworksModel}.
+ *
+ * <p>This only creates the object. Computation is performed once
+ * {@link ANeuralNetworksExecution_startCompute} is invoked.
+ *
+ * The model should be constructed with calls to
+ * {@link ANeuralNetworksModel_addOperation} and
+ * {@link ANeuralNetworksModel_addOperand}
+ *
+ * <p>{@link ANeuralNetworksModel_finish} should be called once the model
+ * has been fully constructed.</p>
+ *
+ * <p>{@link ANeuralNetworksModel_free} should be called once the model
+ * is no longer needed.</p>
+ *
+ * @param model The {@link ANeuralNetworksModel} to be created.
+ * Set to NULL if unsuccessful.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful.
+ */
+inline int ANeuralNetworksModel_create(ANeuralNetworksModel** model) {
+ LOAD_FUNCTION(ANeuralNetworksModel_create);
+ EXECUTE_FUNCTION_RETURN(model);
+}
+
+/**
+ * Destroy a model.
+ *
+ * The model need not have been finished by a call to
+ * {@link ANeuralNetworksModel_finish}.
+ *
+ * See {@link ANeuralNetworksModel} for information on multithreaded usage.
+ *
+ * @param model The model to be destroyed. Passing NULL is acceptable and
+ * results in no operation.
+ */
+inline void ANeuralNetworksModel_free(ANeuralNetworksModel* model) {
+ LOAD_FUNCTION(ANeuralNetworksModel_free);
+ EXECUTE_FUNCTION(model);
+}
+
+/**
+ * Indicate that we have finished modifying a model. Required before
+ * calling {@link ANeuralNetworksCompilation_compile}.
+ *
+ * An application is responsible to make sure that no other thread uses
+ * the model at the same time.
+ *
+ * See {@link ANeuralNetworksModel} for information on multithreaded usage.
+ *
+ * @param model The model to be finished.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful.
+ */
+inline int ANeuralNetworksModel_finish(ANeuralNetworksModel* model) {
+ LOAD_FUNCTION(ANeuralNetworksModel_finish);
+ EXECUTE_FUNCTION_RETURN(model);
+}
+
+/**
+ * Add an operand to a model.
+ *
+ * The order in which the operands are added is important. The first one added
+ * to a model will have the index value 0, the second 1, etc. These indexes are
+ * used as operand identifiers in {@link ANeuralNetworksModel_addOperation},
+ * {@link ANeuralNetworksExecution_setInput},
+ * {@link ANeuralNetworksExecution_setInputFromMemory},
+ * {@link ANeuralNetworksExecution_setOutput},
+ * {@link ANeuralNetworksExecution_setOutputFromMemory} and
+ * {@link ANeuralNetworksExecution_setOperandValue}.
+ *
+ * To build a model that can accomodate inputs of various sizes, as you may want
+ * to do for a CNN, set the size of the dimensions that will vary at run time to
+ * 0. If you do so, provide the full dimensions when calling
+ * {@link ANeuralNetworksExecution_setInput} or {@link
+ * ANeuralNetworksExecution_setInputFromMemory}.
+ *
+ * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has
+ * been called will return an error.
+ *
+ * See {@link ANeuralNetworksModel} for information on multithreaded usage.
+ *
+ * @param model The model to be modified.
+ * @param type The {@link ANeuralNetworksOperandType} that describes the shape
+ * of the operand.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful.
+ */
+inline int ANeuralNetworksModel_addOperand(
+ ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type) {
+ LOAD_FUNCTION(ANeuralNetworksModel_addOperand);
+ EXECUTE_FUNCTION_RETURN(model, type);
+}
+
+/**
+ * Sets an operand to a constant value.
+ *
+ * For scalar values, the content of buffer is copied into the model.
+ *
+ * For tensor values, a pointer to the buffer is stored within the model.
+ * The application is responsible for not changing the content of this region
+ * until all executions using this model have completed. As the data may
+ * be copied during processing, modifying the data after this call yields
+ * undefined results.
+ *
+ * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has
+ * been called will return an error.
+ *
+ * See {@link ANeuralNetworksModel} for information on multithreaded usage.
+ *
+ * @param model The model to be modified.
+ * @param index The index of the model operand we're setting.
+ * @param buffer A pointer to the data to use.
+ * @param length The size in bytes of the data value.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful.
+ */
+inline int ANeuralNetworksModel_setOperandValue(ANeuralNetworksModel* model,
+ int32_t index,
+ const void* buffer,
+ size_t length) {
+ LOAD_FUNCTION(ANeuralNetworksModel_setOperandValue);
+ EXECUTE_FUNCTION_RETURN(model, index, buffer, length);
+}
+
+/**
+ * Sets an operand to a value stored in a memory object.
+ *
+ * The content of the memory is not copied. A reference to that memory is stored
+ * inside the model. The application is responsible for not changing the content
+ * of the memory region until all executions using this model have completed.
+ * As the data may be copied during processing, modifying the data after this
+ * call yields undefined results.
+ *
+ * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has
+ * been called will return an error.
+ *
+ * See {@link ANeuralNetworksModel} for information on multithreaded usage.
+ *
+ * @param model The model to be modified.
+ * @param index The index of the model operand we're setting.
+ * @param buffer A pointer to the data to use.
+ * @param memory The memory containing the data.
+ * @param offset This specifies the location of the data within the memory.
+ * The offset is in bytes from the start of memory.
+ * @param length The size in bytes of the data value.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful.
+ */
+inline int ANeuralNetworksModel_setOperandValueFromMemory(
+ ANeuralNetworksModel* model, int32_t index,
+ const ANeuralNetworksMemory* memory, size_t offset, size_t length) {
+ LOAD_FUNCTION(ANeuralNetworksModel_setOperandValueFromMemory);
+ EXECUTE_FUNCTION_RETURN(model, index, memory, offset, length);
+}
+
+/**
+ * Add an operation to a model.
+ *
+ * @param model The model to be modified.
+ * @param type The type of the operation.
+ * @param inputCount The number of entries in the inputs array.
+ * @param inputs An array of indexes identifying each operand.
+ * @param outputCount The number of entries in the outputs array.
+ * @param outputs An array of indexes identifying each operand.
+ *
+ * The operands specified by inputs and outputs must have been
+ * previously added by calls to {@link ANeuralNetworksModel_addOperand}.
+ *
+ * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has
+ * been called will return an error.
+ *
+ * See {@link ANeuralNetworksModel} for information on multithreaded usage.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful.
+ */
+inline int ANeuralNetworksModel_addOperation(ANeuralNetworksModel* model,
+ ANeuralNetworksOperationType type,
+ uint32_t inputCount,
+ const uint32_t* inputs,
+ uint32_t outputCount,
+ const uint32_t* outputs) {
+ LOAD_FUNCTION(ANeuralNetworksModel_addOperation);
+ EXECUTE_FUNCTION_RETURN(model, type, inputCount, inputs, outputCount,
+ outputs);
+}
+
+/**
+ * Specfifies which operands will be the model's inputs and outputs.
+ *
+ * An operand cannot be used for both input and output. Doing so will
+ * return an error.
+ *
+ * @param model The model to be modified.
+ * @param inputCount The number of entries in the inputs array.
+ * @param inputs An array of indexes identifying the input operands.
+ * @param outputCount The number of entries in the outputs array.
+ * @param outputs An array of indexes identifying the output operands.
+ *
+ * The operands specified by inputs and outputs must have been
+ * previously added by calls to {@link ANeuralNetworksModel_addOperand}.
+ *
+ * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has
+ * been called will return an error.
+ *
+ * See {@link ANeuralNetworksModel} for information on multithreaded usage.
+ *
+ */
+inline int ANeuralNetworksModel_identifyInputsAndOutputs(
+ ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs,
+ uint32_t outputCount, const uint32_t* outputs) {
+ LOAD_FUNCTION(ANeuralNetworksModel_identifyInputsAndOutputs);
+ EXECUTE_FUNCTION_RETURN(model, inputCount, inputs, outputCount, outputs);
+}
+
+/**
+ * Create a {@link ANeuralNetworksCompilation} to compile the given model.
+ * This only creates the object. Compilation is only performed once
+ * {@link ANeuralNetworksCompilation_start} is invoked.
+ *
+ * <p>The provided model must outlive the compilation.</p>
+ *
+ * The model must already have been finished by a call to
+ * {@link ANeuralNetworksModel_finish}.
+ *
+ * See {@link ANeuralNetworksCompilation} for information on multithreaded
+ * usage.
+ *
+ * @param model The {@link ANeuralNetworksModel} to be compiled.
+ * @param compilation The newly created object or NULL if unsuccessful.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA
+ * if the model is invalid.
+ */
+inline int ANeuralNetworksCompilation_create(
+ ANeuralNetworksModel* model, ANeuralNetworksCompilation** compilation) {
+ LOAD_FUNCTION(ANeuralNetworksCompilation_create);
+ EXECUTE_FUNCTION_RETURN(model, compilation);
+}
+
+/**
+ * Destroy a compilation.
+ *
+ * <p>If called on a compilation for which
+ * {@link ANeuralNetworksCompilation_start} has been called, the
+ * function will return immediately but will mark the compilation to be deleted
+ * once the compilation completes. The {@link ANeuralNetworksCompilation_wait}
+ * will return ERROR_DELETED.
+ *
+ * See {@link ANeuralNetworksCompilation} for information on multithreaded
+ * usage.
+ *
+ * @param compilation The compilation to be destroyed. Passing NULL is
+ * acceptable and results in no operation.
+ */
+inline void ANeuralNetworksCompilation_free(
+ ANeuralNetworksCompilation* compilation) {
+ LOAD_FUNCTION(ANeuralNetworksCompilation_free);
+ EXECUTE_FUNCTION(compilation);
+}
+
+/**
+ * Sets the execution preference.
+ *
+ * <p>Provides guidance to the runtime when trade-offs are possible.</p>
+ *
+ * See {@link ANeuralNetworksCompilation} for information on multithreaded
+ * usage.
+ *
+ * @param compilation The compilation to be modified.
+ * @param preference Either {@link PREFER_LOW_POWER},
+ * {@link PREFER_SINGLE_FAST_ANSWER}, or
+ * {@link PREFER_SUSTAINED_SPEED}.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful.
+ */
+inline int ANeuralNetworksCompilation_setPreference(
+ ANeuralNetworksCompilation* compilation, int32_t preference) {
+ LOAD_FUNCTION(ANeuralNetworksCompilation_setPreference);
+ EXECUTE_FUNCTION_RETURN(compilation, preference);
+}
+
+/**
+ * Waits until the compilation completes.
+ *
+ * More than one thread can wait on a compilation. When the compilation
+ * completes, all threads will be released.
+ *
+ * See {@link ANeuralNetworksCompilation} for information on multithreaded
+ * usage.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if the compilation completed normally.
+ */
+inline int ANeuralNetworksCompilation_finish(
+ ANeuralNetworksCompilation* compilation) {
+ LOAD_FUNCTION(ANeuralNetworksCompilation_finish);
+ EXECUTE_FUNCTION_RETURN(compilation);
+}
+/**
+ * Create a {@link ANeuralNetworksExecution} to apply the given compilation.
+ * This only creates the object. Computation is only performed once
+ * {@link ANeuralNetworksExecution_startCompute} is invoked.
+ *
+ * <p>The provided compilation must outlive the execution.</p>
+ *
+ * See {@link ANeuralNetworksExecution} for information on multithreaded usage.
+ *
+ * @param compilation The {@link ANeuralNetworksCompilation} to be evaluated.
+ * @param execution The newly created object or NULL if unsuccessful.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA
+ * if the compilation is invalid.
+ */
+inline int ANeuralNetworksExecution_create(
+ ANeuralNetworksCompilation* compilation,
+ ANeuralNetworksExecution** execution) {
+ LOAD_FUNCTION(ANeuralNetworksExecution_create);
+ EXECUTE_FUNCTION_RETURN(compilation, execution);
+}
+
+/**
+ * Destroy an execution.
+ *
+ * <p>If called on an execution for which
+ * {@link ANeuralNetworksExecution_startCompute} has been called, the
+ * function will return immediately but will mark the execution to be deleted
+ * once the computation completes. The {link ANeuralNetworksExecution_wait}
+ * will return ANEURALNETWORKS_ERROR_DELETED.
+ *
+ * See {@link ANeuralNetworksExecution} for information on multithreaded usage.
+ *
+ * @param execution The execution to be destroyed. Passing NULL is acceptable
+ * and results in no operation.
+ */
+inline void ANeuralNetworksExecution_free(ANeuralNetworksExecution* execution) {
+ LOAD_FUNCTION(ANeuralNetworksExecution_free);
+ EXECUTE_FUNCTION(execution);
+}
+
+/**
+ * Associate a user buffer with an input of the model of the
+ * {@link ANeuralNetworksExecution}.
+ *
+ * <p>The provided buffer must outlive the execution.</p>
+ *
+ * See {@link ANeuralNetworksExecution} for information on multithreaded usage.
+ *
+ * @param execution The execution to be modified.
+ * @param index The index of the input argument we are setting. It is
+ * an index into the lists passed to
+ * {@link ANeuralNetworksModel_identifyInputsAndOutputs}. It is not
+ * the index associated with {@link
+ * ANeuralNetworksModel_addOperand}.
+ * @param type The type of the operand. This should be used to specify the
+ * dimensions that were set to 0 when the operand was added to the
+ * model. All other properties of the type must be the same as
+ * specified in the model. If the type is the same as specified
+ * when the model was built, NULL can be passed.
+ * @param buffer The buffer containing the data.
+ * @param length The length in bytes of the buffer.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA if
+ * the name is not recognized or the buffer is too small for the input.
+ */
+inline int ANeuralNetworksExecution_setInput(
+ ANeuralNetworksExecution* execution, int32_t index,
+ const ANeuralNetworksOperandType* type, const void* buffer, size_t length) {
+ LOAD_FUNCTION(ANeuralNetworksExecution_setInput);
+ EXECUTE_FUNCTION_RETURN(execution, index, type, buffer, length);
+}
+
+/**
+ * Associate part of a memory object with an input of the model of the
+ * {@link ANeuralNetworksExecution}.
+ *
+ * <p>The provided memory must outlive the execution.</p>
+ *
+ * See {@link ANeuralNetworksExecution} for information on multithreaded usage.
+ *
+ * @param execution The execution to be modified.
+ * @param index The index of the input argument we are setting. It is
+ * an index into the lists passed to
+ * {@link ANeuralNetworksModel_identifyInputsAndOutputs}. It is not
+ * the index associated with {@link
+ * ANeuralNetworksModel_addOperand}.
+ * @param type The type of the operand. This can be used to specify the
+ * dimensions that were set to 0 when the operand was added to the
+ * model. All other values must be the same as specified in the
+ * model. If the type is the same as specified when the model
+ * was built, NULL can be passed.
+ * @param memory The memory containing the data.
+ * @param offset This specifies the location of the data whithin the memory.
+ * The offset is in bytes from the start of memory.
+ * @param length The size in bytes of the data value.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA if
+ * the name is not recognized or the buffer is too small for the input.
+ */
+inline int ANeuralNetworksExecution_setInputFromMemory(
+ ANeuralNetworksExecution* execution, int32_t index,
+ const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory,
+ size_t offset, size_t length) {
+ LOAD_FUNCTION(ANeuralNetworksExecution_setInputFromMemory);
+ EXECUTE_FUNCTION_RETURN(execution, index, type, memory, offset, length);
+}
+
+/**
+ * Associate a user buffer with an output of the model of the
+ * {@link ANeuralNetworksExecution}.
+ *
+ * <p>The provided buffer must outlive the execution.</p>
+ *
+ * See {@link ANeuralNetworksExecution} for information on multithreaded usage.
+ *
+ * @param execution The execution to be modified.
+ * @param index The index of the output argument we are setting. It is
+ * an index into the lists passed to
+ * {@link ANeuralNetworksModel_identifyInputsAndOutputs}. It is not
+ * the index associated with {@link
+ * ANeuralNetworksModel_addOperand}.
+ * @param type The type of the operand. This can be used to specify the
+ * dimensions that were set to 0 when the operand was added to the
+ * model. All other values must be the same as specified in the
+ * model. If the type is the same as specified when the model
+ * was built, NULL can be passed.
+ * @param buffer The buffer where the data is to be written.
+ * @param length The length in bytes of the buffer.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA if
+ * the name is not recognized or the buffer is too small for the output.
+ */
+inline int ANeuralNetworksExecution_setOutput(
+ ANeuralNetworksExecution* execution, int32_t index,
+ const ANeuralNetworksOperandType* type, void* buffer, size_t length) {
+ LOAD_FUNCTION(ANeuralNetworksExecution_setOutput);
+ EXECUTE_FUNCTION_RETURN(execution, index, type, buffer, length);
+}
+
+/**
+ * Associate part of a memory object with an output of the model of the
+ * {@link ANeuralNetworksExecution}.
+ *
+ * <p>The provided memory must outlive the execution.</p>
+ *
+ * See {@link ANeuralNetworksExecution} for information on multithreaded usage.
+ *
+ * @param execution The execution to be modified.
+ * @param index The index of the output argument we are setting. It is
+ * an index into the lists passed to
+ * {@link ANeuralNetworksModel_identifyInputsAndOutputs}. It is not
+ * the index associated with {@link
+ * ANeuralNetworksModel_addOperand}.
+ * @param type The type of the operand. This can be used to specify the
+ * dimensions that were set to 0 when the operand was added to the
+ * model. All other values must be the same as specified in the
+ * model. If the type is the same as specified when the model
+ * was built, NULL can be passed.
+ * @param memory The memory where the data is to be stored.
+ * @param offset This specifies the location of the data whithin the memory.
+ * The offset is in bytes from the start of memory.
+ * @param length The length in bytes of the data value.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA if
+ * the name is not recognized or the buffer is too small for the output.
+ */
+inline int ANeuralNetworksExecution_setOutputFromMemory(
+ ANeuralNetworksExecution* execution, int32_t index,
+ const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory,
+ size_t offset, size_t length) {
+ LOAD_FUNCTION(ANeuralNetworksExecution_setOutputFromMemory);
+ EXECUTE_FUNCTION_RETURN(execution, index, type, memory, offset, length);
+}
+
+/**
+ * Schedule evaluation of the execution.
+ *
+ * <p>Schedules evaluation of the execution. Once the model has been
+ * applied and the outputs are ready to be consumed, the execution will be
+ * signaled. Use {@link ANeuralNetworksExecution_wait} to wait for that signal.
+ * </p>
+ *
+ * Multiple executions can be scheduled and evaluated concurrently, and
+ * compilations can be performed concurrently with executions. The runtime makes
+ * no guarantee on the ordering of the completion of compilations and
+ * executions. If it's important to the application, the application should
+ * enforce the ordering by using {@link ANeuralNetworksCompilation_wait} and
+ * {@link ANeuralNetworksExecution_wait}.
+ *
+ * ANeuralNetworksExecution_wait must be called to recuperate the resources used
+ * by the execution.
+ *
+ * See {@link ANeuralNetworksExecution} for information on multithreaded usage.
+ *
+ * @param execution The execution to be scheduled and executed.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful.
+ */
+inline int ANeuralNetworksExecution_startCompute(
+ ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event) {
+ LOAD_FUNCTION(ANeuralNetworksExecution_startCompute);
+ EXECUTE_FUNCTION_RETURN(execution, event);
+}
+
+/**
+ * Waits until the execution completes.
+ *
+ * More than one thread can wait on an event. When the execution completes,
+ * all threads will be released.
+ *
+ * See {@link ANeuralNetworksExecution} for information on multithreaded usage.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if the execution completed normally.
+ */
+inline int ANeuralNetworksEvent_wait(ANeuralNetworksEvent* event) {
+ LOAD_FUNCTION(ANeuralNetworksEvent_wait);
+ EXECUTE_FUNCTION_RETURN(event);
+}
+
+/**
+ * Destroys the event.
+ *
+ * See {@link ANeuralNetworksExecution} for information on multithreaded usage.
+ */
+inline void ANeuralNetworksEvent_free(ANeuralNetworksEvent* event) {
+ LOAD_FUNCTION(ANeuralNetworksEvent_free);
+ EXECUTE_FUNCTION(event);
+}
+
+/**/
+
+#endif // NN_API_SHIM_H0
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
new file mode 100644
index 0000000000..6a199cc840
--- /dev/null
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -0,0 +1,386 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/nnapi_delegate.h"
+#include <fcntl.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h"
+
+namespace tflite {
+
+// TODO(aselle): FATAL leaves resources hanging.
+void FATAL(const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ vfprintf(stderr, format, args);
+ va_end(args);
+ fflush(stderr);
+ exit(1);
+}
+
+// TODO(aselle): Change the error model to use status codes.
+#define CHECK_TFLITE_SUCCESS(x) \
+ if (x != kTfLiteOk) { \
+ FATAL("Aborting since tflite returned failure."); \
+ }
+
+#define CHECK_NN(x) \
+ if (x != ANEURALNETWORKS_NO_ERROR) { \
+ FATAL("Aborting since tflite returned failure."); \
+ }
+
+NNAPIAllocation::NNAPIAllocation(const char* filename,
+ ErrorReporter* error_reporter)
+ : MMAPAllocation(filename, error_reporter) {
+ if (mmapped_buffer_ != MAP_FAILED)
+ CHECK_NN(ANeuralNetworksMemory_createFromFd(buffer_size_bytes_, PROT_READ,
+ mmap_fd_, 0, &handle_));
+}
+
+NNAPIAllocation::~NNAPIAllocation() {
+ if (handle_) {
+ ANeuralNetworksMemory_free(handle_);
+ }
+}
+
+NNAPIDelegate::~NNAPIDelegate() {
+ if (nn_model_) {
+ ANeuralNetworksModel_free(nn_model_);
+ nn_model_ = nullptr;
+ // TODO(aselle): Is this thread-safe and callable multiple times?
+ }
+ // ANeuralNetworksShutdown();
+}
+
+// Adds the tensors of the interpreter to the NN API model.
+// Returns the number of operands added.
+uint32_t addTensorOperands(tflite::Interpreter* interpreter,
+ ANeuralNetworksModel* nn_model) {
+ uint32_t next_id = 0;
+ for (size_t i = 0; i < interpreter->tensors_size(); i++) {
+ int32_t nn_type = 0;
+ float scale = 1.0f;
+ int32_t zeroPoint = 0;
+ TfLiteTensor* tensor = interpreter->tensor(i);
+ switch (tensor->type) {
+ case kTfLiteNoType:
+ // Tensors added during initialization of Ops don't have a type yet and
+ // should not be registered with the NNAPI.
+ continue;
+ case kTfLiteFloat32:
+ nn_type = ANEURALNETWORKS_TENSOR_FLOAT32;
+ break;
+ case kTfLiteUInt8:
+ nn_type = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM;
+ scale = tensor->params.scale;
+ zeroPoint = tensor->params.zero_point;
+ break;
+ case kTfLiteInt32:
+ nn_type = ANEURALNETWORKS_TENSOR_INT32;
+ scale = tensor->params.scale;
+ zeroPoint = tensor->params.zero_point;
+ break;
+ default:
+ FATAL("Unsupported type.");
+ }
+ // TODO(aselle): Note, many of these are intermediate results. Do I need
+ // to ever specify these sizes. I am currently below doing setValue
+ // on all of them, but I shouldn't in the future.
+ // Answer(jeanluc): If all the operators can set the dimension correctly,
+ // you won't need to.
+ ANeuralNetworksOperandType operand_type{
+ nn_type, static_cast<uint32_t>(tensor->dims->size),
+ reinterpret_cast<uint32_t*>(tensor->dims->data), scale, zeroPoint};
+ CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type));
+
+ // TODO(aselle): Based on Michael's suggestion, limiting this to read
+ // only memory
+ if (tensor->allocation_type == kTfLiteMmapRo) {
+ if (const NNAPIAllocation* alloc = dynamic_cast<const NNAPIAllocation*>(
+ static_cast<const Allocation*>(tensor->allocation))) {
+ CHECK_NN(ANeuralNetworksModel_setOperandValueFromMemory(
+ nn_model, i, alloc->memory(), alloc->offset(tensor->data.raw),
+ tensor->bytes));
+ } else {
+ CHECK_NN(ANeuralNetworksModel_setOperandValue(
+ nn_model, i, tensor->data.raw, tensor->bytes));
+ }
+ }
+ ++next_id;
+ }
+ return next_id;
+}
+
+// Adds the operations and their parameters to the NN API model.
+// 'next-id' is the operand ID of the next operand of the model.
+void AddOpsAndParams(tflite::Interpreter* interpreter,
+ ANeuralNetworksModel* nn_model, uint32_t next_id) {
+ for (size_t i = 0; i < interpreter->nodes_size(); i++) {
+ const auto* node_and_registration = interpreter->node_and_registration(i);
+ const TfLiteNode& node = node_and_registration->first;
+ const TfLiteRegistration& registration = node_and_registration->second;
+ tflite::BuiltinOperator builtin =
+ static_cast<tflite::BuiltinOperator>(registration.builtin_code);
+
+ // Add the parameters.
+ std::vector<uint32_t> augmented_inputs(
+ node.inputs->data, node.inputs->data + node.inputs->size);
+
+ auto add_scalar_int32 = [&nn_model, &augmented_inputs,
+ &next_id](int value) {
+ ANeuralNetworksOperandType operand_type{.type = ANEURALNETWORKS_INT32};
+ CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type))
+ CHECK_NN(ANeuralNetworksModel_setOperandValue(nn_model, next_id, &value,
+ sizeof(int32_t)))
+ augmented_inputs.push_back(next_id++);
+ };
+
+ auto add_scalar_float32 = [&nn_model, &augmented_inputs,
+ &next_id](float value) {
+ ANeuralNetworksOperandType operand_type{.type = ANEURALNETWORKS_FLOAT32};
+ CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type))
+ CHECK_NN(ANeuralNetworksModel_setOperandValue(nn_model, next_id, &value,
+ sizeof(float)))
+ augmented_inputs.push_back(next_id++);
+ };
+
+ auto add_add_params = [&add_scalar_int32]() { add_scalar_int32(0); };
+
+ auto add_pooling_params = [&add_scalar_int32](void* data) {
+ auto builtin = reinterpret_cast<TfLitePoolParams*>(data);
+ add_scalar_int32(builtin->padding);
+ add_scalar_int32(builtin->stride_width);
+ add_scalar_int32(builtin->stride_height);
+ add_scalar_int32(builtin->filter_width);
+ add_scalar_int32(builtin->filter_height);
+ add_scalar_int32(builtin->activation);
+ };
+
+ auto add_convolution_params = [&add_scalar_int32](void* data) {
+ auto builtin = reinterpret_cast<TfLiteConvParams*>(data);
+ add_scalar_int32(builtin->padding);
+ add_scalar_int32(builtin->stride_width);
+ add_scalar_int32(builtin->stride_height);
+ add_scalar_int32(builtin->activation);
+ };
+
+ auto add_depthwise_conv_params = [&add_scalar_int32](void* data) {
+ auto builtin = reinterpret_cast<TfLiteDepthwiseConvParams*>(data);
+ add_scalar_int32(builtin->padding);
+ add_scalar_int32(builtin->stride_width);
+ add_scalar_int32(builtin->stride_height);
+ add_scalar_int32(builtin->depth_multiplier);
+ add_scalar_int32(builtin->activation);
+ };
+
+ auto add_fully_connected_params = [&add_scalar_int32](void* data) {
+ auto builtin = reinterpret_cast<TfLiteFullyConnectedParams*>(data);
+ add_scalar_int32(builtin->activation);
+ };
+
+ auto add_concatenation_params = [&add_scalar_int32](void* data) {
+ auto builtin = reinterpret_cast<TfLiteConcatenationParams*>(data);
+ add_scalar_int32(builtin->axis);
+ if (builtin->activation != kTfLiteActNone) {
+ FATAL("Concatenation does not support fused activation in NNAPI");
+ }
+ };
+
+ auto add_softmax_params = [&add_scalar_float32](void* data) {
+ auto builtin = reinterpret_cast<TfLiteSoftmaxParams*>(data);
+ add_scalar_float32(builtin->beta);
+ };
+
+#if 0
+ auto add_reshape_params = [&](void* data) {
+ auto builtin = reinterpret_cast<TfLiteReshapeParams*>(data);
+ uint32_t tensor_size_shape = builtin->num_dimensions;
+ ANeuralNetworksOperandType operand_type{
+ ANEURALNETWORKS_TENSOR_INT32,
+ {static_cast<uint32_t>(1),
+ reinterpret_cast<uint32_t*>(&tensor_size_shape)},
+ 0,
+ 0};
+ CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type))
+ CHECK_NN(ANeuralNetworksModel_setOperandValue(
+ nn_model, next_id, builtin->shape,
+ sizeof(int) * builtin->num_dimensions));
+ augmented_inputs.push_back(next_id++);
+ };
+#endif
+
+ ANeuralNetworksOperationType nn_op_type;
+ switch (builtin) {
+ case tflite::BuiltinOperator_ADD:
+ nn_op_type = ANEURALNETWORKS_ADD;
+ add_add_params();
+ break;
+ case tflite::BuiltinOperator_AVERAGE_POOL_2D:
+ add_pooling_params(node.builtin_data);
+ nn_op_type = ANEURALNETWORKS_AVERAGE_POOL_2D;
+ break;
+ case tflite::BuiltinOperator_MAX_POOL_2D:
+ add_pooling_params(node.builtin_data);
+ nn_op_type = ANEURALNETWORKS_MAX_POOL_2D;
+ break;
+ case tflite::BuiltinOperator_L2_POOL_2D:
+ add_pooling_params(node.builtin_data);
+ nn_op_type = ANEURALNETWORKS_L2_POOL_2D;
+ break;
+ case tflite::BuiltinOperator_CONV_2D:
+ add_convolution_params(node.builtin_data);
+ nn_op_type = ANEURALNETWORKS_CONV_2D;
+ break;
+ case tflite::BuiltinOperator_RELU:
+ nn_op_type = ANEURALNETWORKS_RELU;
+ break;
+ case tflite::BuiltinOperator_RELU6:
+ nn_op_type = ANEURALNETWORKS_RELU6;
+ break;
+ case tflite::BuiltinOperator_TANH:
+ nn_op_type = ANEURALNETWORKS_TANH;
+ break;
+ case tflite::BuiltinOperator_LOGISTIC:
+ nn_op_type = ANEURALNETWORKS_LOGISTIC;
+ break;
+ case tflite::BuiltinOperator_DEPTHWISE_CONV_2D:
+ add_depthwise_conv_params(node.builtin_data);
+ nn_op_type = ANEURALNETWORKS_DEPTHWISE_CONV_2D;
+ break;
+ case tflite::BuiltinOperator_CONCATENATION:
+ add_concatenation_params(node.builtin_data);
+ nn_op_type = ANEURALNETWORKS_CONCATENATION;
+ break;
+ case tflite::BuiltinOperator_SOFTMAX:
+ add_softmax_params(node.builtin_data);
+ nn_op_type = ANEURALNETWORKS_SOFTMAX;
+ break;
+ case tflite::BuiltinOperator_FULLY_CONNECTED:
+ add_fully_connected_params(node.builtin_data);
+ nn_op_type = ANEURALNETWORKS_FULLY_CONNECTED;
+ break;
+ case tflite::BuiltinOperator_RESHAPE:
+ nn_op_type = ANEURALNETWORKS_RESHAPE;
+ // add_reshape_params(node.builtin_data);
+ break;
+ case tflite::BuiltinOperator_CONCAT_EMBEDDINGS:
+ case tflite::BuiltinOperator_LSH_PROJECTION:
+ case tflite::BuiltinOperator_SVDF:
+ case tflite::BuiltinOperator_HASHTABLE_LOOKUP:
+ case tflite::BuiltinOperator_RNN:
+ case tflite::BuiltinOperator_EMBEDDING_LOOKUP:
+ case tflite::BuiltinOperator_EMBEDDING_LOOKUP_SPARSE:
+ case tflite::BuiltinOperator_LSTM:
+ case tflite::BuiltinOperator_L2_NORMALIZATION:
+ case tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION:
+ case tflite::BuiltinOperator_MUL:
+ case tflite::BuiltinOperator_RESIZE_BILINEAR:
+ case tflite::BuiltinOperator_CALL:
+ case tflite::BuiltinOperator_SKIP_GRAM:
+ case tflite::BuiltinOperator_RELU1:
+ case tflite::BuiltinOperator_SPACE_TO_DEPTH:
+ FATAL("Op code %d is currently not delegated to NNAPI", builtin);
+ nn_op_type = -1; // set to invalid
+ break;
+ case tflite::BuiltinOperator_CUSTOM:
+ FATAL("Custom operations are not supported when using NNAPI.");
+ nn_op_type = -1; // set to invalid
+ break;
+ }
+
+ // Add the operation.
+ CHECK_NN(ANeuralNetworksModel_addOperation(
+ nn_model, nn_op_type, static_cast<uint32_t>(augmented_inputs.size()),
+ augmented_inputs.data(), static_cast<uint32_t>(node.outputs->size),
+ reinterpret_cast<uint32_t*>(node.outputs->data)));
+ }
+}
+
+TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) {
+ // TODO(aselle): This is not correct. need to handle resize invalidation.
+ if (nn_model_ && nn_compiled_model_) return kTfLiteOk;
+
+ if (!nn_model_) {
+ CHECK_NN(ANeuralNetworksModel_create(&nn_model_));
+
+ uint32_t next_id = addTensorOperands(interpreter, nn_model_);
+ AddOpsAndParams(interpreter, nn_model_, next_id);
+ CHECK_NN(ANeuralNetworksModel_identifyInputsAndOutputs(
+ nn_model_, static_cast<uint32_t>(interpreter->inputs().size()),
+ reinterpret_cast<const uint32_t*>(interpreter->inputs().data()),
+ static_cast<uint32_t>(interpreter->outputs().size()),
+ reinterpret_cast<const uint32_t*>(interpreter->outputs().data())));
+ CHECK_NN(ANeuralNetworksModel_finish(nn_model_));
+ }
+ if (!nn_compiled_model_) {
+ CHECK_NN(ANeuralNetworksCompilation_create(nn_model_, &nn_compiled_model_));
+ CHECK_NN(ANeuralNetworksCompilation_finish(nn_compiled_model_));
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) {
+ if (!nn_model_) {
+ TF_LITE_ENSURE_STATUS(BuildGraph(interpreter));
+ }
+
+ ANeuralNetworksExecution* execution = nullptr;
+ CHECK_NN(ANeuralNetworksExecution_create(nn_compiled_model_, &execution));
+
+ // Currently perform deep copy of input buffer
+ for (size_t i = 0; i < interpreter->inputs().size(); i++) {
+ int input = interpreter->inputs()[i];
+ // TODO(aselle): Is this what we want or do we want input instead?
+ // TODO(aselle): This should be called setInputValue maybe to be cons.
+ TfLiteTensor* tensor = interpreter->tensor(input);
+ CHECK_NN(ANeuralNetworksExecution_setInput(
+ execution, i, nullptr, tensor->data.raw, tensor->bytes));
+ }
+ // Tell nn api where to place final data.
+ for (size_t i = 0; i < interpreter->outputs().size(); i++) {
+ int output = interpreter->outputs()[i];
+ TfLiteTensor* tensor = interpreter->tensor(output);
+ CHECK_NN(ANeuralNetworksExecution_setOutput(
+ execution, i, nullptr, tensor->data.raw, tensor->bytes));
+ }
+ // Currently use blocking compute.
+ ANeuralNetworksEvent* event = nullptr;
+ CHECK_NN(ANeuralNetworksExecution_startCompute(execution, &event));
+ CHECK_NN(ANeuralNetworksEvent_wait(event));
+ ANeuralNetworksEvent_free(event);
+ ANeuralNetworksExecution_free(execution);
+
+#if 0
+ printf("From the NN API:\n");
+ TfLiteTensor* tensor = interpreter->tensor(interpreter->outputs()[0]);
+ if (float* data =
+ interpreter->typed_tensor<float>(interpreter->outputs()[0])) {
+ size_t num = tensor->bytes / sizeof(float);
+ for (float* p = data; p < data + num; p++) {
+ printf(" %f", *p);
+ }
+ printf("\n");
+ }
+#endif
+
+ return kTfLiteOk;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/nnapi_delegate.h b/tensorflow/contrib/lite/nnapi_delegate.h
new file mode 100644
index 0000000000..f29aa9e18e
--- /dev/null
+++ b/tensorflow/contrib/lite/nnapi_delegate.h
@@ -0,0 +1,66 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_
+
+#include "tensorflow/contrib/lite/allocation.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h"
+
+class ANeuralNetworsModel;
+
+namespace tflite {
+
+class NNAPIAllocation : public MMAPAllocation {
+ public:
+ NNAPIAllocation(const char* filename, ErrorReporter* error_reporter);
+ ~NNAPIAllocation();
+
+ size_t offset(const void* ptr) const {
+ auto signed_offset = reinterpret_cast<const uint8_t*>(ptr) -
+ reinterpret_cast<const uint8_t*>(mmapped_buffer_);
+
+ return static_cast<size_t>(signed_offset);
+ }
+
+ ANeuralNetworksMemory* memory() const { return handle_; }
+ bool valid() const override { return handle_ != nullptr; }
+
+ private:
+ mutable ANeuralNetworksMemory* handle_ = nullptr;
+};
+
+class NNAPIDelegate {
+ public:
+ ~NNAPIDelegate();
+
+ // Convert a tflite graph to NNAPI
+ TfLiteStatus BuildGraph(Interpreter* interpreter);
+
+ // Run
+ TfLiteStatus Invoke(Interpreter* interpreter);
+
+ private:
+ // The NN API model handle
+ ANeuralNetworksModel* nn_model_ = nullptr;
+ // The NN API compilation handle
+ ANeuralNetworksCompilation* nn_compiled_model_ = nullptr;
+};
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_
diff --git a/tensorflow/contrib/lite/optional_debug_tools.cc b/tensorflow/contrib/lite/optional_debug_tools.cc
new file mode 100644
index 0000000000..1f762e6688
--- /dev/null
+++ b/tensorflow/contrib/lite/optional_debug_tools.cc
@@ -0,0 +1,108 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/optional_debug_tools.h"
+
+namespace tflite {
+
+void PrintIntVector(const std::vector<int>& v) {
+ for (const auto& it : v) {
+ printf(" %d", it);
+ }
+ printf("\n");
+}
+
+void PrintTfLiteIntVector(const TfLiteIntArray* v) {
+ if (!v) {
+ printf(" (null)");
+ return;
+ }
+ for (int k = 0; k < v->size; k++) {
+ printf(" %d", v->data[k]);
+ }
+ printf("\n");
+}
+
+const char* TensorTypeName(TfLiteType type) {
+ switch (type) {
+ case kTfLiteNoType:
+ return "kTfLiteNoType";
+ case kTfLiteFloat32:
+ return "kTfLiteFloat32";
+ case kTfLiteInt32:
+ return "kTfLiteInt32";
+ case kTfLiteUInt8:
+ return "kTfLiteUInt8";
+ case kTfLiteInt64:
+ return "kTfLiteInt64";
+ case kTfLiteString:
+ return "kTfLiteString";
+ }
+ return "(invalid)";
+}
+
+const char* AllocTypeName(TfLiteAllocationType type) {
+ switch (type) {
+ case kTfLiteMemNone:
+ return "kTfLiteMemNone";
+ case kTfLiteMmapRo:
+ return "kTfLiteMmapRo";
+ case kTfLiteDynamic:
+ return "kTfLiteDynamic";
+ case kTfLiteArenaRw:
+ return "kTfLiteArenaRw";
+ case kTfLiteArenaRwPersistent:
+ return "kTfLiteArenaRwPersistent";
+ }
+ return "(invalid)";
+}
+
+// Prints a dump of what tensors and what nodes are in the interpreter.
+void PrintInterpreterState(Interpreter* interpreter) {
+ printf("Interpreter has %d tensors and %d nodes\n",
+ interpreter->tensors_size(), interpreter->nodes_size());
+ printf("Inputs:");
+ PrintIntVector(interpreter->inputs());
+ printf("Outputs:");
+ PrintIntVector(interpreter->outputs());
+ printf("\n");
+ for (int tensor_index = 0; tensor_index < interpreter->tensors_size();
+ tensor_index++) {
+ TfLiteTensor* tensor = interpreter->tensor(tensor_index);
+ printf("Tensor %3d %10s %15s %10zu bytes (%4.1f MB) ", tensor_index,
+ TensorTypeName(tensor->type), AllocTypeName(tensor->allocation_type),
+ tensor->bytes, float(tensor->bytes) / float(1 << 20));
+ PrintTfLiteIntVector(tensor->dims);
+ printf("\n");
+ }
+
+ for (int node_index = 0; node_index < interpreter->nodes_size();
+ node_index++) {
+ const std::pair<TfLiteNode, TfLiteRegistration>* node_and_reg =
+ interpreter->node_and_registration(node_index);
+ const TfLiteNode& node = node_and_reg->first;
+ const TfLiteRegistration& reg = node_and_reg->second;
+ printf("Node %3d Operator Builtin Code %3d\n", node_index,
+ reg.builtin_code);
+ printf(" Inputs:");
+ PrintTfLiteIntVector(node.inputs);
+ printf(" Outputs:");
+ PrintTfLiteIntVector(node.outputs);
+ }
+}
+
+// Prints a dump of what tensors and what nodes are in the interpreter.
+TfLiteStatus ValidateInterpreterState(const Interpreter* interpreter);
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/optional_debug_tools.h b/tensorflow/contrib/lite/optional_debug_tools.h
new file mode 100644
index 0000000000..54d4876095
--- /dev/null
+++ b/tensorflow/contrib/lite/optional_debug_tools.h
@@ -0,0 +1,32 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Optional debugging functionality. For small sized binaries, these are not
+// needed.
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_
+
+#include "tensorflow/contrib/lite/interpreter.h"
+
+namespace tflite {
+
+// Prints a dump of what tensors and what nodes are in the interpreter.
+void PrintInterpreterState(Interpreter* interpreter);
+
+// Prints a dump of what tensors and what nodes are in the interpreter.
+TfLiteStatus ValidateInterpreterState(const Interpreter* interpreter);
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
new file mode 100644
index 0000000000..b4aa032ff8
--- /dev/null
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -0,0 +1,46 @@
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_library(
+ name = "lite",
+ srcs = ["lite.py"],
+ # data = [
+ # "//tensorflow/contrib/lite/toco/python:toco_from_protos",
+ # ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/contrib/lite/toco:model_flags_proto_py",
+ "//tensorflow/contrib/lite/toco:toco_flags_proto_py",
+ "//tensorflow/contrib/lite/toco/python:tensorflow_wrap_toco",
+ "//tensorflow/python:platform",
+ ],
+)
+
+py_test(
+ name = "lite_test",
+ srcs = ["lite_test.py"],
+ deps = [
+ ":lite",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:session",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
new file mode 100644
index 0000000000..5e8edbb937
--- /dev/null
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -0,0 +1,199 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TensorFlow Lite tooling helper functionality.
+
+EXPERIMENTAL: APIs here are unstable and likely to change without notice.
+
+@@toco_convert
+@@toco_convert_protos
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import subprocess
+import tempfile
+
+from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2
+from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2
+from tensorflow.contrib.lite.toco.python.tensorflow_wrap_toco import TocoConvert as _toco_convert_protos
+from tensorflow.python.framework import dtypes as _dtypes
+# from tensorflow.python.platform import
+# resource_loader as _resource_loader
+
+# Enum types from the protobuf promoted to the API
+FLOAT = _toco_flags_pb2.FLOAT
+INT32 = _toco_flags_pb2.INT32
+INT64 = _toco_flags_pb2.INT64
+STRING = _toco_flags_pb2.STRING
+QUANTIZED_UINT8 = _toco_flags_pb2.QUANTIZED_UINT8
+TENSORFLOW_GRAPHDEF = _toco_flags_pb2.TENSORFLOW_GRAPHDEF
+TFLITE = _toco_flags_pb2.TFLITE
+GRAPHVIZ_DOT = _toco_flags_pb2.GRAPHVIZ_DOT
+
+# Currently the default mode of operation is to shell to another python process
+# to protect against crashes.
+EXPERIMENTAL_USE_TOCO_API_DIRECTLY = True
+
+# Find the toco_from_protos binary using the resource loader if using from
+# bazel, otherwise we are in a pip where console_scripts already has
+# the toco_from_protos tool.
+# toco_from_proto_bin = _resource_loader.get_path_to_datafile(
+# "../toco/python/toco_from_protos")
+# if not os.path.exists(toco_from_proto_bin):
+# toco_from_proto_bin = "toco_from_protos"
+
+
+def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
+ """Convert `input_data_str` according to model and toco parameters.
+
+ Unless you know what you are doing consider using
+ the more friendly @{tf.contrib.lite.toco_convert}}.
+
+ Args:
+ model_flags_str: Serialized proto describing model properties, see
+ `toco/model_flags.proto`.
+ toco_flags_str: Serialized proto describing conversion properties, see
+ `toco/toco_flags.proto`.
+ input_data_str: Input data in serialized form (e.g. a graphdef is common)
+ Returns:
+ Converted model in serialized form (e.g. a TFLITE model is common).
+ Raises:
+ RuntimeError: When conversion fails, an exception is raised with the error
+ message embedded.
+ """
+ # TODO(aselle): When toco does not use fatal errors for failure, we can
+ # switch this on.
+ if EXPERIMENTAL_USE_TOCO_API_DIRECTLY:
+ return _toco_convert_protos(model_flags_str, toco_flags_str, input_data_str)
+
+ # with tempfile.NamedTemporaryFile() as fp_toco, \
+ # tempfile.NamedTemporaryFile() as fp_model, \
+ # tempfile.NamedTemporaryFile() as fp_input, \
+ # tempfile.NamedTemporaryFile() as fp_output:
+ # fp_model.write(model_flags_str)
+ # fp_toco.write(toco_flags_str)
+ # fp_input.write(input_data_str)
+ # fp_model.flush()
+ # fp_toco.flush()
+ # fp_input.flush()
+
+ # cmd = [
+ # toco_from_proto_bin, fp_model.name, fp_toco.name, fp_input.name,
+ # fp_output.name
+ # ]
+ # cmdline = " ".join(cmd)
+ # proc = subprocess.Popen(
+ # cmdline,
+ # shell=True,
+ # stdout=subprocess.PIPE,
+ # stderr=subprocess.STDOUT,
+ # close_fds=True)
+ # stdout, stderr = proc.communicate()
+ # exitcode = proc.returncode
+ # if exitcode == 0:
+ # stuff = fp_output.read()
+ # return stuff
+ # else:
+ # raise RuntimeError("TOCO failed see console for info.\n%s\n%s\n" %
+ # (stdout, stderr))
+
+
+def _tensor_name(x):
+ return x.name.split(":")[0]
+
+
+def toco_convert(input_data,
+ input_tensors,
+ output_tensors,
+ inference_type=FLOAT,
+ input_format=TENSORFLOW_GRAPHDEF,
+ output_format=TFLITE,
+ quantized_input_stats=None,
+ drop_control_dependency=True):
+ """Convert a model using TOCO from `input_format` to `output_format`.
+
+ Typically this is to convert from TensorFlow GraphDef to TFLite, in which
+ case the default `input_format` and `output_format` are sufficient.
+
+ Args:
+ input_data: Input data (i.e. often `sess.graph_def`).
+ input_tensors: List of input tensors. Type and shape are computed using
+ `foo.get_shape()` and `foo.dtype`.
+ output_tensors: List of output tensors (only .name is used from this).
+ inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`.
+ input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF).
+ output_format: Type of data to write (currently must be TFLITE or
+ GRAPHVIZ_DOT)
+ quantized_input_stats: For each member of input_tensors the mean and
+ std deviation of training data. Only needed if `inference_type` is
+ `QUANTIZED_UINT8`.
+ drop_control_dependency: Drops control dependencies silently. This is due
+ to tf lite not supporting control dependencies.
+
+ Returns:
+ The converted data. For example if tflite was the destination, then
+ this will be a tflite flatbuffer in a bytes array.
+
+ Raises:
+ ValueError: If the input tensor type is unknown
+ RuntimeError: If TOCO fails to convert (in which case the runtime error's
+ error text will contain the TOCO error log)
+ """
+ toco = _toco_flags_pb2.TocoFlags()
+ toco.input_format = input_format
+ toco.output_format = output_format
+ model = _model_flags_pb2.ModelFlags()
+ model.drop_control_dependency = drop_control_dependency
+ toco.inference_type = inference_type
+ for idx, input_tensor in enumerate(input_tensors):
+ if input_tensor.dtype == _dtypes.float32:
+ tflite_input_type = FLOAT
+ elif input_tensor.dtype == _dtypes.int32:
+ tflite_input_type = INT32
+ elif input_tensor.dtype == _dtypes.int64:
+ tflite_input_type = INT64
+ # TODO(aselle): Insert strings when they are available
+ else:
+ raise ValueError("Tensors %s not known type %r" % (input_tensor.name,
+ input_tensor.dtype))
+
+ input_array = model.input_arrays.add()
+
+ if inference_type == QUANTIZED_UINT8:
+ if tflite_input_type == FLOAT:
+ tflite_input_type = QUANTIZED_UINT8
+ input_array.mean, input_array.std = quantized_input_stats[idx]
+
+ input_array.name = _tensor_name(input_tensor)
+ input_array.shape.extend(map(int, input_tensor.get_shape()))
+ toco.input_types.append(tflite_input_type)
+
+ for output_tensor in output_tensors:
+ model.output_arrays.append(_tensor_name(output_tensor))
+
+ data = toco_convert_protos(model.SerializeToString(),
+ toco.SerializeToString(),
+ input_data.SerializeToString())
+ return data
+
+
+# remove_undocumented(__name__)
+
+del os
+del subprocess
+del tempfile
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
new file mode 100644
index 0000000000..da360aeb34
--- /dev/null
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -0,0 +1,45 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TensorFlow Lite Python Interface: Sanity check."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.lite.python import lite
+from tensorflow.python.client import session
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class LiteTest(test_util.TensorFlowTestCase):
+
+ def testBasic(self):
+ in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3],
+ dtype=dtypes.float32)
+ out_tensor = in_tensor + in_tensor
+ sess = session.Session()
+ # Try running on valid graph
+ result = lite.toco_convert(sess.graph_def, [in_tensor], [out_tensor])
+ self.assertTrue(result)
+ # TODO(aselle): remove tests that fail.
+ # Try running on identity graph (known fail)
+ # with self.assertRaisesRegexp(RuntimeError, "!model->operators.empty()"):
+ # result = lite.toco_convert(sess.graph_def, [in_tensor], [in_tensor])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/contrib/lite/schema/BUILD
new file mode 100644
index 0000000000..3e04d6f34f
--- /dev/null
+++ b/tensorflow/contrib/lite/schema/BUILD
@@ -0,0 +1,82 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_binary(
+ name = "upgrade_schema",
+ srcs = [
+ "upgrade_schema.py",
+ ],
+ data = [
+ "schema_v0.fbs",
+ "schema_v1.fbs",
+ "schema_v2.fbs",
+ "schema_v3.fbs",
+ ],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:platform",
+ "@flatbuffers//:flatc",
+ ],
+)
+
+py_test(
+ name = "upgrade_schema_test",
+ size = "small",
+ srcs = ["upgrade_schema_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":upgrade_schema",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ ],
+)
+
+exports_files([
+ "schema_v0.fbs",
+ "schema_v1.fbs",
+ "schema_v2.fbs",
+ "schema_v3.fbs",
+])
+
+load("//third_party/flatbuffers:build_defs.bzl", "flatbuffer_cc_library")
+
+# Generic schema for inference on device.
+flatbuffer_cc_library(
+ name = "schema_fbs",
+ srcs = ["schema.fbs"],
+)
+
+# Schema test to make sure we don't introduce backward incompatible changes
+# to schemas.
+cc_test(
+ name = "flatbuffer_compatibility_test",
+ size = "small",
+ srcs = ["flatbuffer_compatibility_test.cc"],
+ data = [
+ "schema.fbs",
+ "schema_v3.fbs",
+ ],
+ deps = [
+ "//tensorflow/core:lib_platform",
+ "@com_google_googletest//:gtest",
+ "@flatbuffers//:flatc_library",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc b/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc
new file mode 100644
index 0000000000..17ee0af8dd
--- /dev/null
+++ b/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc
@@ -0,0 +1,91 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <fstream>
+#include <gtest/gtest.h>
+#include "third_party/flatbuffers/include/flatbuffers/flatc.h"
+#include "tensorflow/core/platform/platform.h"
+
+#ifdef PLATFORM_GOOGLE
+#define TFLITE_TF_PREFIX "third_party/tensorflow/"
+#else
+#define TFLITE_TF_PREFIX "tensorflow/"
+#endif
+/// Load filename `name`
+bool LoadFileRaw(const char *name, std::string *buf) {
+ std::ifstream fp(name, std::ios::binary);
+ if (!fp) {
+ fprintf(stderr, "Failed to read '%s'\n", name);
+ return false;
+ }
+ std::string s((std::istreambuf_iterator<char>(fp)),
+ std::istreambuf_iterator<char>());
+ if (s.empty()) {
+ fprintf(stderr, "Read '%s' resulted in empty\n", name);
+ return false;
+ }
+ *buf = s;
+ return true;
+}
+
+bool ParseFile(flatbuffers::Parser *parser, const std::string &filename,
+ const std::string &contents) {
+ std::vector<const char *> include_directories;
+ auto local_include_directory = flatbuffers::StripFileName(filename);
+ include_directories.push_back(local_include_directory.c_str());
+ include_directories.push_back(nullptr);
+ if (!parser->Parse(contents.c_str(), include_directories.data(),
+ filename.c_str())) {
+ fprintf(stderr, "Failed to parse flatbuffer schema '%s'\n",
+ contents.c_str());
+ return false;
+ }
+ return true;
+}
+
+// Checks to make sure current schema in current code does not cause an
+// incompatibility.
+TEST(SchemaTest, TestCompatibility) {
+ // Read file contents of schemas into strings
+ // TODO(aselle): Need a reliable way to load files.
+ std::string base_contents, current_contents;
+ const char *base_filename =
+ TFLITE_TF_PREFIX "contrib/lite/schema/schema_v3.fbs";
+ const char *current_filename =
+ TFLITE_TF_PREFIX "contrib/lite/schema/schema.fbs";
+
+ ASSERT_TRUE(LoadFileRaw(base_filename, &base_contents));
+ ASSERT_TRUE(LoadFileRaw(current_filename, &current_contents));
+ // Parse the schemas
+ flatbuffers::Parser base_parser, current_parser;
+ std::vector<const char *> include_directories;
+ ASSERT_TRUE(ParseFile(&base_parser, base_filename, base_contents));
+ ASSERT_TRUE(ParseFile(&current_parser, current_filename, current_contents));
+ // Check that the schemas conform and fail if they don't
+ auto err = current_parser.ConformTo(base_parser);
+ if (!err.empty()) {
+ fprintf(stderr,
+ "Schemas don't conform:\n%s\n"
+ "In other words some change you made means that new parsers can't"
+ "parse old files.\n",
+ err.c_str());
+ FAIL();
+ }
+}
+
+int main(int argc, char **argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
new file mode 100644
index 0000000000..ddb2ab792c
--- /dev/null
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -0,0 +1,346 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Revision History
+// Version 0: Initial version.
+// Version 1: Add subgraphs to schema.
+// Version 2: Rename operators to conform to NN API.
+// Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers.
+
+namespace tflite;
+
+// This corresponds to the version.
+file_identifier "TFL3";
+// File extension of any written files.
+file_extension "tflite";
+
+// The type of data stored in a tensor.
+enum TensorType : byte {
+ FLOAT32 = 0,
+ FLOAT16 = 1,
+ INT32 = 2,
+ UINT8 = 3,
+ INT64 = 4,
+ STRING = 5,
+}
+
+// Parameters for converting a quantized tensor back to float. Given a
+// quantized value q, the corresponding float value f should be:
+// f = scale * (q - zero_point)
+table QuantizationParameters {
+ min:[float]; // For importing back into tensorflow.
+ max:[float]; // For importing back into tensorflow.
+ scale:[float];
+ zero_point:[long];
+}
+
+table Tensor {
+ // The tensor shape. The meaning of each entry is operator-specific but
+ // builtin ops use: [batch size, number of channels, height, width] (That's
+ // Tensorflow's NCHW).
+ shape:[int];
+ type:TensorType;
+ // An index that refers to the buffers table at the root of the model. Or,
+ // if there is no data buffer associated (i.e. intermediate results), then
+ // this is 0 (which refers to an always existant empty buffer).
+ //
+ // The data_buffer itself is an opaque container, with the assumption that the
+ // target device is little-endian. In addition, all builtin operators assume
+ // the memory is ordered such that if `shape` is [4, 3, 2], then index
+ // [i, j, k] maps to data_buffer[i*3*2 + j*3 + k].
+ buffer:uint;
+ name:string; // For debugging and importing back into tensorflow.
+ quantization:QuantizationParameters; // Optional.
+}
+
+// A list of builtin operators. Builtin operators a slighlty faster than custom
+// ones, but not by much. Moreover, while custom operators accept an opaque
+// object containing configuration parameters, builtins have a predetermined
+// set of acceptable options.
+enum BuiltinOperator : byte {
+ ADD = 0,
+ AVERAGE_POOL_2D = 1,
+ CONCATENATION = 2,
+ CONV_2D = 3,
+ DEPTHWISE_CONV_2D = 4,
+ // DEPTH_TO_SPACE = 5,
+ // DEQUANTIZE = 6,
+ EMBEDDING_LOOKUP = 7,
+ // FLOOR = 8,
+ FULLY_CONNECTED = 9,
+ HASHTABLE_LOOKUP = 10,
+ L2_NORMALIZATION = 11,
+ L2_POOL_2D = 12,
+ LOCAL_RESPONSE_NORMALIZATION = 13,
+ LOGISTIC = 14,
+ LSH_PROJECTION = 15,
+ LSTM = 16,
+ MAX_POOL_2D = 17,
+ MUL = 18,
+ RELU = 19,
+ RELU1 = 20,
+ RELU6 = 21,
+ RESHAPE = 22,
+ RESIZE_BILINEAR = 23,
+ RNN = 24,
+ SOFTMAX = 25,
+ SPACE_TO_DEPTH = 26,
+ SVDF = 27,
+ TANH = 28,
+ // TODO(aselle): Consider rename to CONCATENATE_EMBEDDINGS
+ CONCAT_EMBEDDINGS = 29,
+ SKIP_GRAM = 30,
+ CALL = 31,
+ CUSTOM = 32,
+ EMBEDDING_LOOKUP_SPARSE = 33,
+}
+
+// Options for the builtin operators.
+union BuiltinOptions {
+ Conv2DOptions,
+ DepthwiseConv2DOptions,
+ ConcatEmbeddingsOptions,
+ LSHProjectionOptions,
+ Pool2DOptions,
+ SVDFOptions,
+ RNNOptions,
+ FullyConnectedOptions,
+ SoftmaxOptions,
+ ConcatenationOptions,
+ AddOptions,
+ L2NormOptions,
+ LocalResponseNormalizationOptions,
+ LSTMOptions,
+ ResizeBilinearOptions,
+ CallOptions,
+ ReshapeOptions,
+ SkipGramOptions,
+ SpaceToDepthOptions,
+ EmbeddingLookupSparseOptions,
+ MulOptions,
+}
+
+enum Padding : byte { SAME, VALID }
+
+enum ActivationFunctionType : byte {
+ NONE = 0,
+ RELU = 1,
+ RELU1 = 2,
+ RELU6 = 3,
+ TANH = 4,
+ SIGN_BIT = 5,
+}
+
+table Conv2DOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table Pool2DOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ filter_width:int;
+ filter_height:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table DepthwiseConv2DOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ depth_multiplier:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table ConcatEmbeddingsOptions {
+ num_channels:int;
+ num_columns_per_channel:[int];
+ embedding_dim_per_channel:[int]; // This could be inferred from parameters.
+}
+
+enum LSHProjectionType: byte {
+ UNKNOWN = 0,
+ SPARSE = 1,
+ DENSE = 2,
+}
+
+table LSHProjectionOptions {
+ type: LSHProjectionType;
+}
+
+table SVDFOptions {
+ rank:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+// An implementation of TensorFlow RNNCell.
+table RNNOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+// An implementation of TensorFlow fully_connected (a.k.a Dense) layer.
+table FullyConnectedOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table SoftmaxOptions {
+ beta: float;
+}
+
+// An implementation of TensorFlow concat.
+table ConcatenationOptions {
+ axis:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table AddOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table MulOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table L2NormOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table LocalResponseNormalizationOptions {
+ radius:int;
+ bias:float;
+ alpha:float;
+ beta:float;
+}
+
+// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell
+table LSTMOptions {
+ fused_activation_function:ActivationFunctionType;
+ cell_clip: float; // Optional, 0.0 means no clipping
+ proj_clip: float; // Optional, 0.0 means no clipping
+}
+
+table ResizeBilinearOptions {
+ new_height:int;
+ new_width:int;
+}
+
+// A call operation options
+table CallOptions {
+ // The subgraph index that needs to be called.
+ subgraph:uint;
+}
+
+table ReshapeOptions {
+ new_shape:[int];
+}
+
+table SkipGramOptions {
+ ngram_size: int;
+ max_skip_size: int;
+ include_all_ngrams: bool;
+}
+
+table SpaceToDepthOptions {
+ block_size: int;
+}
+
+enum CombinerType : byte {
+ SUM = 0,
+ MEAN = 1,
+ SQRTN = 2,
+}
+
+table EmbeddingLookupSparseOptions {
+ combiner:CombinerType;
+}
+
+// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
+// builtin, or a string if the operator is custom.
+table OperatorCode {
+ builtin_code:BuiltinOperator;
+ custom_code:string;
+}
+
+enum CustomOptionsFormat : byte {
+ FLEXBUFFERS = 0,
+}
+
+// An operator takes tensors as inputs and outputs. The type of operation being
+// performed is determined by an index into the list of valid OperatorCodes,
+// while the specifics of each operations is configured using builtin_options
+// or custom_options.
+table Operator {
+ // Index into the operator_codes array. Using an integer here avoids
+ // complicate map lookups.
+ opcode_index:uint;
+
+ // Optional input and output tensors are indicated by -1.
+ inputs:[int];
+ outputs:[int];
+
+ builtin_options:BuiltinOptions;
+ custom_options:[ubyte];
+ custom_options_format:CustomOptionsFormat;
+}
+
+// The root type, defining a model.
+table SubGraph {
+ // A list of all tensors used in this model.
+ tensors:[Tensor];
+
+ // Indices of the input tensors.
+ inputs:[int];
+
+ // Indices of the output tensors.
+ outputs:[int];
+
+ // All operators, in execution order.
+ operators:[Operator];
+
+ // Name of subgraph (used for debugging).
+ name:string;
+}
+
+// Table of raw data buffers (used for constant tensors). Referenced by tensors
+// by index.
+table Buffer {
+ data:[ubyte];
+}
+
+table Model {
+ // Version of the schema.
+ version:uint;
+
+ // A list of all operator codes used in this model. This is
+ // kept in order because operators carry an index into this
+ // vector.
+ operator_codes:[OperatorCode];
+
+ // All the subgraphs of the model. The 0th is assumed to be the main
+ // model.
+ subgraphs:[SubGraph];
+
+ // A description of the model.
+ description:string;
+
+ // Buffers of the model
+ buffers:[Buffer];
+
+}
+
+root_type Model;
+
diff --git a/tensorflow/contrib/lite/schema/schema_v0.fbs b/tensorflow/contrib/lite/schema/schema_v0.fbs
new file mode 100644
index 0000000000..852ea988f3
--- /dev/null
+++ b/tensorflow/contrib/lite/schema/schema_v0.fbs
@@ -0,0 +1,247 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+namespace tflite;
+
+// The type of data stored in a tensor.
+enum TensorType : byte {
+ FLOAT32 = 0,
+ FLOAT16 = 1,
+ INT32 = 2,
+ UINT8 = 3,
+ INT64 = 4,
+}
+
+// Parameters for converting a quantized tensor back to float. Given a
+// quantized value q, the corresponding float value f should be:
+// f = scale * (q - zero_point)
+table QuantizationParameters {
+ min:[float]; // For importing back into tensorflow.
+ max:[float]; // For importing back into tensorflow.
+ scale:[float];
+ zero_point:[long];
+}
+
+table Tensor {
+ // The tensor shape. The meaning of each entry is operator-specific but
+ // builtin ops use: [batch size, number of channels, height, width] (That's
+ // Tensorflow's NCHW).
+ shape:[int];
+ type:TensorType;
+ // The data_buffer is an opaque container, with the assumption that the
+ // target device is little-endian. In addition, all builtin operators assume
+ // the memory is ordered such that if `shape` is [4, 3, 2], then index
+ // [i, j, k] maps to data_buffer[i*4*3 + j*3 + k].
+ data_buffer:[ubyte];
+ name:string; // For debugging and importing back into tensorflow.
+ quantization:QuantizationParameters; // Optional.
+}
+
+// A list of builtin operators. Builtin operators a slighlty faster than custom
+// ones, but not by much. Moreover, while custom operators accept an opaque
+// object containing configuration parameters, builtins have a predetermined
+// set of acceptable options.
+enum BuiltinOperator : byte {
+ CUSTOM = 0,
+ CONVOLUTION = 1,
+ DEPTHWISE_CONVOLUTION = 2,
+ CONCAT_EMBEDDINGS = 3,
+ LSH_PROJECTION = 4,
+ TANH = 5,
+ RELU = 6,
+ AVERAGE_POOL = 7,
+ MAX_POOL = 8,
+ L2_POOL = 9,
+ SIGMOID = 10,
+ SVDF = 11,
+ BasicRNN = 12,
+ RELU6 = 13,
+ EMBEDDING_LOOKUP = 14,
+ FULLY_CONNECTED = 15,
+ HASHTABLE_LOOKUP = 16,
+ SOFTMAX = 17,
+ CONCATENATION = 18,
+ LSTM = 19,
+ ADD = 20,
+ L2NORM = 21,
+ LOCAL_RESPONSE_NORM = 22,
+ RESIZE_BILINEAR = 23,
+}
+
+// Options for the builtin operators.
+union BuiltinOptions {
+ ConvolutionOptions,
+ DepthwiseConvolutionOptions,
+ ConcatEmbeddingsOptions,
+ LSHProjectionOptions,
+ PoolOptions,
+ SVDFOptions,
+ BasicRNNOptions,
+ FullyConnectedOptions,
+ SoftmaxOptions,
+ ConcatenationOptions,
+ AddOptions,
+ L2NormOptions,
+ LocalResponseNormOptions,
+ LSTMOptions,
+ ResizeBilinearOptions,
+}
+
+enum Padding : byte { SAME, VALID }
+
+enum ActivationFunctionType : byte {
+ NONE = 0,
+ RELU = 1,
+ RELU1 = 2,
+ RELU6 = 3,
+ TANH = 4,
+ SIGN_BIT = 5,
+}
+
+table ConvolutionOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table PoolOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ filter_width:int;
+ filter_height:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table DepthwiseConvolutionOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ depth_multiplier:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table ConcatEmbeddingsOptions {
+ num_channels:int;
+ num_columns_per_channel:[int];
+ embedding_dim_per_channel:[int]; // This could be inferred from parameters.
+}
+
+enum LSHProjectionType: byte {
+ UNKNOWN = 0,
+ SPARSE = 1,
+ DENSE = 2,
+}
+
+table LSHProjectionOptions {
+ type: LSHProjectionType;
+}
+
+table SVDFOptions {
+ rank:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+// An implementation of TensorFlow BasicRNNCell.
+table BasicRNNOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+// An implementation of TensorFlow fully_connected (a.k.a Dense) layer.
+table FullyConnectedOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table SoftmaxOptions {
+ beta: float;
+}
+
+// An implementation of TensorFlow concat.
+table ConcatenationOptions {
+ axis:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table AddOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table L2NormOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table LocalResponseNormOptions {
+ radius:int;
+ bias:float;
+ alpha:float;
+ beta:float;
+}
+
+// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell
+table LSTMOptions {
+ fused_activation_function:ActivationFunctionType;
+ cell_clip: float; // Optional, 0.0 means no clipping
+ proj_clip: float; // Optional, 0.0 means no clipping
+}
+
+table ResizeBilinearOptions {
+ new_height:int;
+ new_width:int;
+}
+
+// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
+// builtin, or a string if the operator is custom.
+table OperatorCode {
+ builtin_code:BuiltinOperator;
+ custom_code:string;
+}
+
+// An operator takes tensors as inputs and outputs. The type of operation being
+// performed is determined by an index into the list of valid OperatorCodes,
+// while the specifics of each operations is configured using builtin_options
+// or custom_options.
+table Operator {
+ // Index into the operator_codes array. Using an integer here avoids
+ // complicate map lookups.
+ opcode_index:int;
+
+ inputs:[int];
+ outputs:[int];
+
+ builtin_options:BuiltinOptions;
+ custom_options:[ubyte];
+}
+
+// The root type, defining a model.
+table Model {
+ // A list of all tensors used in this model.
+ tensors:[Tensor];
+
+ // Indices of the input tensors.
+ inputs:[int];
+
+ // Indices of the output tensors.
+ outputs:[int];
+
+ // A list of all operator codes used in this model. This is
+ // kept in order because operators carry an index into this
+ // vector.
+ operator_codes:[OperatorCode];
+
+ // All operators, in execution order.
+ operators:[Operator];
+}
+
+root_type Model;
diff --git a/tensorflow/contrib/lite/schema/schema_v1.fbs b/tensorflow/contrib/lite/schema/schema_v1.fbs
new file mode 100644
index 0000000000..06cd9408ed
--- /dev/null
+++ b/tensorflow/contrib/lite/schema/schema_v1.fbs
@@ -0,0 +1,295 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Revision History
+// Version 0: Initial version.
+// Version 1: Add subgraphs to schema.
+
+namespace tflite;
+
+// The type of data stored in a tensor.
+enum TensorType : byte {
+ FLOAT32 = 0,
+ FLOAT16 = 1,
+ INT32 = 2,
+ UINT8 = 3,
+ INT64 = 4,
+ STRING = 5,
+}
+
+// Parameters for converting a quantized tensor back to float. Given a
+// quantized value q, the corresponding float value f should be:
+// f = scale * (q - zero_point)
+table QuantizationParameters {
+ min:[float]; // For importing back into tensorflow.
+ max:[float]; // For importing back into tensorflow.
+ scale:[float];
+ zero_point:[long];
+}
+
+table Tensor {
+ // The tensor shape. The meaning of each entry is operator-specific but
+ // builtin ops use: [batch size, number of channels, height, width] (That's
+ // Tensorflow's NCHW).
+ shape:[int];
+ type:TensorType;
+ // The data_buffer is an opaque container, with the assumption that the
+ // target device is little-endian. In addition, all builtin operators assume
+ // the memory is ordered such that if `shape` is [4, 3, 2], then index
+ // [i, j, k] maps to data_buffer[i*3*2 + j*3 + k].
+ data_buffer:[ubyte];
+ name:string; // For debugging and importing back into tensorflow.
+ quantization:QuantizationParameters; // Optional.
+}
+
+// A list of builtin operators. Builtin operators a slighlty faster than custom
+// ones, but not by much. Moreover, while custom operators accept an opaque
+// object containing configuration parameters, builtins have a predetermined
+// set of acceptable options.
+enum BuiltinOperator : byte {
+ CUSTOM = 0,
+ CONVOLUTION = 1,
+ DEPTHWISE_CONVOLUTION = 2,
+ CONCAT_EMBEDDINGS = 3,
+ LSH_PROJECTION = 4,
+ TANH = 5,
+ RELU = 6,
+ AVERAGE_POOL = 7,
+ MAX_POOL = 8,
+ L2_POOL = 9,
+ SIGMOID = 10,
+ SVDF = 11,
+ BasicRNN = 12,
+ RELU6 = 13,
+ EMBEDDING_LOOKUP = 14,
+ FULLY_CONNECTED = 15,
+ HASHTABLE_LOOKUP = 16,
+ SOFTMAX = 17,
+ CONCATENATION = 18,
+ LSTM = 19,
+ ADD = 20,
+ L2NORM = 21,
+ LOCAL_RESPONSE_NORM = 22,
+ RESIZE_BILINEAR = 23,
+ CALL = 24,
+ RESHAPE = 25,
+ SKIP_GRAM = 26,
+ SPACE_TO_DEPTH = 27,
+}
+
+// Options for the builtin operators.
+union BuiltinOptions {
+ ConvolutionOptions,
+ DepthwiseConvolutionOptions,
+ ConcatEmbeddingsOptions,
+ LSHProjectionOptions,
+ PoolOptions,
+ SVDFOptions,
+ BasicRNNOptions,
+ FullyConnectedOptions,
+ SoftmaxOptions,
+ ConcatenationOptions,
+ AddOptions,
+ L2NormOptions,
+ LocalResponseNormOptions,
+ LSTMOptions,
+ ResizeBilinearOptions,
+ CallOptions,
+ ReshapeOptions,
+ SkipGramOptions,
+ SpaceToDepthOptions,
+}
+
+enum Padding : byte { SAME, VALID }
+
+enum ActivationFunctionType : byte {
+ NONE = 0,
+ RELU = 1,
+ RELU1 = 2,
+ RELU6 = 3,
+ TANH = 4,
+ SIGN_BIT = 5,
+}
+
+table ConvolutionOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table PoolOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ filter_width:int;
+ filter_height:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table DepthwiseConvolutionOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ depth_multiplier:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table ConcatEmbeddingsOptions {
+ num_channels:int;
+ num_columns_per_channel:[int];
+ embedding_dim_per_channel:[int]; // This could be inferred from parameters.
+}
+
+enum LSHProjectionType: byte {
+ UNKNOWN = 0,
+ SPARSE = 1,
+ DENSE = 2,
+}
+
+table LSHProjectionOptions {
+ type: LSHProjectionType;
+}
+
+table SVDFOptions {
+ rank:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+// An implementation of TensorFlow BasicRNNCell.
+table BasicRNNOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+// An implementation of TensorFlow fully_connected (a.k.a Dense) layer.
+table FullyConnectedOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table SoftmaxOptions {
+ beta: float;
+}
+
+// An implementation of TensorFlow concat.
+table ConcatenationOptions {
+ axis:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table AddOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table L2NormOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table LocalResponseNormOptions {
+ radius:int;
+ bias:float;
+ alpha:float;
+ beta:float;
+}
+
+// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell
+table LSTMOptions {
+ fused_activation_function:ActivationFunctionType;
+ cell_clip: float; // Optional, 0.0 means no clipping
+ proj_clip: float; // Optional, 0.0 means no clipping
+}
+
+table ResizeBilinearOptions {
+ new_height:int;
+ new_width:int;
+}
+
+// A call operation options
+table CallOptions {
+ // The subgraph index that needs to be called.
+ subgraph:int;
+}
+
+table ReshapeOptions {
+ new_shape:[int];
+}
+
+table SkipGramOptions {
+ ngram_size: int;
+ max_skip_size: int;
+ include_all_ngrams: bool;
+}
+
+table SpaceToDepthOptions {
+ block_size: int;
+}
+
+// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
+// builtin, or a string if the operator is custom.
+table OperatorCode {
+ builtin_code:BuiltinOperator;
+ custom_code:string;
+}
+
+// An operator takes tensors as inputs and outputs. The type of operation being
+// performed is determined by an index into the list of valid OperatorCodes,
+// while the specifics of each operations is configured using builtin_options
+// or custom_options.
+table Operator {
+ // Index into the operator_codes array. Using an integer here avoids
+ // complicate map lookups.
+ opcode_index:int;
+
+ inputs:[int];
+ outputs:[int];
+
+ builtin_options:BuiltinOptions;
+ custom_options:[ubyte];
+}
+
+// The root type, defining a model.
+table SubGraph {
+ // A list of all tensors used in this model.
+ tensors:[Tensor];
+
+ // Indices of the input tensors.
+ inputs:[int];
+
+ // Indices of the output tensors.
+ outputs:[int];
+
+ // All operators, in execution order.
+ operators:[Operator];
+
+ // Name of subgraph (used for debugging).
+ name:string;
+}
+
+table Model {
+ // Version of the schema.
+ version:int;
+
+ // A list of all operator codes used in this model. This is
+ // kept in order because operators carry an index into this
+ // vector.
+ operator_codes:[OperatorCode];
+
+ // All the subgraphs of the model. The 0th is assumed to be the main
+ // model.
+ subgraphs:[SubGraph];
+
+ // A description of the model.
+ description:string;
+}
+
+root_type Model;
diff --git a/tensorflow/contrib/lite/schema/schema_v2.fbs b/tensorflow/contrib/lite/schema/schema_v2.fbs
new file mode 100644
index 0000000000..96731c8aae
--- /dev/null
+++ b/tensorflow/contrib/lite/schema/schema_v2.fbs
@@ -0,0 +1,303 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Revision History
+// Version 0: Initial version.
+// Version 1: Add subgraphs to schema.
+// Version 2: Rename operators to conform to NN API.
+
+namespace tflite;
+
+// The type of data stored in a tensor.
+enum TensorType : byte {
+ FLOAT32 = 0,
+ FLOAT16 = 1,
+ INT32 = 2,
+ UINT8 = 3,
+ INT64 = 4,
+ STRING = 5,
+}
+
+// Parameters for converting a quantized tensor back to float. Given a
+// quantized value q, the corresponding float value f should be:
+// f = scale * (q - zero_point)
+table QuantizationParameters {
+ min:[float]; // For importing back into tensorflow.
+ max:[float]; // For importing back into tensorflow.
+ scale:[float];
+ zero_point:[long];
+}
+
+table Tensor {
+ // The tensor shape. The meaning of each entry is operator-specific but
+ // builtin ops use: [batch size, number of channels, height, width] (That's
+ // Tensorflow's NCHW).
+ shape:[int];
+ type:TensorType;
+ // The data_buffer is an opaque container, with the assumption that the
+ // target device is little-endian. In addition, all builtin operators assume
+ // the memory is ordered such that if `shape` is [4, 3, 2], then index
+ // [i, j, k] maps to data_buffer[i*3*2 + j*3 + k].
+ data_buffer:[ubyte];
+ name:string; // For debugging and importing back into tensorflow.
+ quantization:QuantizationParameters; // Optional.
+}
+
+// A list of builtin operators. Builtin operators a slighlty faster than custom
+// ones, but not by much. Moreover, while custom operators accept an opaque
+// object containing configuration parameters, builtins have a predetermined
+// set of acceptable options.
+enum BuiltinOperator : byte {
+ ADD = 0,
+ AVERAGE_POOL_2D = 1,
+ CONCATENATION = 2,
+ CONV_2D = 3,
+ DEPTHWISE_CONV_2D = 4,
+ // DEPTH_TO_SPACE = 5,
+ // DEQUANTIZE = 6,
+ EMBEDDING_LOOKUP = 7,
+ // FLOOR = 8,
+ FULLY_CONNECTED = 9,
+ HASHTABLE_LOOKUP = 10,
+ L2_NORMALIZATION = 11,
+ L2_POOL_2D = 12,
+ LOCAL_RESPONSE_NORMALIZATION = 13,
+ LOGISTIC = 14,
+ LSH_PROJECTION = 15,
+ LSTM = 16,
+ MAX_POOL_2D = 17,
+ // MUL = 18,
+ RELU = 19,
+ // RELU1=20,
+ RELU6 = 21,
+ RESHAPE = 22,
+ RESIZE_BILINEAR = 23,
+ RNN = 24,
+ SOFTMAX = 25,
+ SPACE_TO_DEPTH = 26,
+ SVDF = 27,
+ TANH = 28,
+ // TODO(aselle): Consider rename to CONCATENATE_EMBEDDINGS
+ CONCAT_EMBEDDINGS = 29,
+ SKIP_GRAM = 30,
+ CALL = 31,
+ CUSTOM = 32,
+
+}
+
+// Options for the builtin operators.
+union BuiltinOptions {
+ Conv2DOptions,
+ DepthwiseConv2DOptions,
+ ConcatEmbeddingsOptions,
+ LSHProjectionOptions,
+ Pool2DOptions,
+ SVDFOptions,
+ RNNOptions,
+ FullyConnectedOptions,
+ SoftmaxOptions,
+ ConcatenationOptions,
+ AddOptions,
+ L2NormOptions,
+ LocalResponseNormalizationOptions,
+ LSTMOptions,
+ ResizeBilinearOptions,
+ CallOptions,
+ ReshapeOptions,
+ SkipGramOptions,
+ SpaceToDepthOptions,
+}
+
+enum Padding : byte { SAME, VALID }
+
+enum ActivationFunctionType : byte {
+ NONE = 0,
+ RELU = 1,
+ RELU1 = 2,
+ RELU6 = 3,
+ TANH = 4,
+ SIGN_BIT = 5,
+}
+
+table Conv2DOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table Pool2DOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ filter_width:int;
+ filter_height:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table DepthwiseConv2DOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ depth_multiplier:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table ConcatEmbeddingsOptions {
+ num_channels:int;
+ num_columns_per_channel:[int];
+ embedding_dim_per_channel:[int]; // This could be inferred from parameters.
+}
+
+enum LSHProjectionType: byte {
+ UNKNOWN = 0,
+ SPARSE = 1,
+ DENSE = 2,
+}
+
+table LSHProjectionOptions {
+ type: LSHProjectionType;
+}
+
+table SVDFOptions {
+ rank:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+// An implementation of TensorFlow RNNCell.
+table RNNOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+// An implementation of TensorFlow fully_connected (a.k.a Dense) layer.
+table FullyConnectedOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table SoftmaxOptions {
+ beta: float;
+}
+
+// An implementation of TensorFlow concat.
+table ConcatenationOptions {
+ axis:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table AddOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table L2NormOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table LocalResponseNormalizationOptions {
+ radius:int;
+ bias:float;
+ alpha:float;
+ beta:float;
+}
+
+// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell
+table LSTMOptions {
+ fused_activation_function:ActivationFunctionType;
+ cell_clip: float; // Optional, 0.0 means no clipping
+ proj_clip: float; // Optional, 0.0 means no clipping
+}
+
+table ResizeBilinearOptions {
+ new_height:int;
+ new_width:int;
+}
+
+// A call operation options
+table CallOptions {
+ // The subgraph index that needs to be called.
+ subgraph:int;
+}
+
+table ReshapeOptions {
+ new_shape:[int];
+}
+
+table SkipGramOptions {
+ ngram_size: int;
+ max_skip_size: int;
+ include_all_ngrams: bool;
+}
+
+table SpaceToDepthOptions {
+ block_size: int;
+}
+
+// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
+// builtin, or a string if the operator is custom.
+table OperatorCode {
+ builtin_code:BuiltinOperator;
+ custom_code:string;
+}
+
+// An operator takes tensors as inputs and outputs. The type of operation being
+// performed is determined by an index into the list of valid OperatorCodes,
+// while the specifics of each operations is configured using builtin_options
+// or custom_options.
+table Operator {
+ // Index into the operator_codes array. Using an integer here avoids
+ // complicate map lookups.
+ opcode_index:int;
+
+ inputs:[int];
+ outputs:[int];
+
+ builtin_options:BuiltinOptions;
+ custom_options:[ubyte];
+}
+
+// The root type, defining a model.
+table SubGraph {
+ // A list of all tensors used in this model.
+ tensors:[Tensor];
+
+ // Indices of the input tensors.
+ inputs:[int];
+
+ // Indices of the output tensors.
+ outputs:[int];
+
+ // All operators, in execution order.
+ operators:[Operator];
+
+ // Name of subgraph (used for debugging).
+ name:string;
+}
+
+table Model {
+ // Version of the schema.
+ version:int;
+
+ // A list of all operator codes used in this model. This is
+ // kept in order because operators carry an index into this
+ // vector.
+ operator_codes:[OperatorCode];
+
+ // All the subgraphs of the model. The 0th is assumed to be the main
+ // model.
+ subgraphs:[SubGraph];
+
+ // A description of the model.
+ description:string;
+}
+
+root_type Model;
diff --git a/tensorflow/contrib/lite/schema/schema_v3.fbs b/tensorflow/contrib/lite/schema/schema_v3.fbs
new file mode 100644
index 0000000000..cedefe08f3
--- /dev/null
+++ b/tensorflow/contrib/lite/schema/schema_v3.fbs
@@ -0,0 +1,326 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Revision History
+// Version 0: Initial version.
+// Version 1: Add subgraphs to schema.
+// Version 2: Rename operators to conform to NN API.
+// Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers.
+
+namespace tflite;
+
+// This corresponds to the version (4).
+file_identifier "TFL3";
+// File extension of any written files.
+file_extension "tflite";
+
+// The type of data stored in a tensor.
+enum TensorType : byte {
+ FLOAT32 = 0,
+ FLOAT16 = 1,
+ INT32 = 2,
+ UINT8 = 3,
+ INT64 = 4,
+ STRING = 5,
+}
+
+// Parameters for converting a quantized tensor back to float. Given a
+// quantized value q, the corresponding float value f should be:
+// f = scale * (q - zero_point)
+table QuantizationParameters {
+ min:[float]; // For importing back into tensorflow.
+ max:[float]; // For importing back into tensorflow.
+ scale:[float];
+ zero_point:[long];
+}
+
+table Tensor {
+ // The tensor shape. The meaning of each entry is operator-specific but
+ // builtin ops use: [batch size, number of channels, height, width] (That's
+ // Tensorflow's NCHW).
+ shape:[int];
+ type:TensorType;
+ // An index that refers to the buffers table at the root of the model. Or,
+ // if there is no data buffer associated (i.e. intermediate results), then
+ // this is 0 (which refers to an always existant empty buffer).
+ //
+ // The data_buffer itself is an opaque container, with the assumption that the
+ // target device is little-endian. In addition, all builtin operators assume
+ // the memory is ordered such that if `shape` is [4, 3, 2], then index
+ // [i, j, k] maps to data_buffer[i*3*2 + j*3 + k].
+ buffer:uint;
+ name:string; // For debugging and importing back into tensorflow.
+ quantization:QuantizationParameters; // Optional.
+}
+
+// A list of builtin operators. Builtin operators a slighlty faster than custom
+// ones, but not by much. Moreover, while custom operators accept an opaque
+// object containing configuration parameters, builtins have a predetermined
+// set of acceptable options.
+enum BuiltinOperator : byte {
+ ADD = 0,
+ AVERAGE_POOL_2D = 1,
+ CONCATENATION = 2,
+ CONV_2D = 3,
+ DEPTHWISE_CONV_2D = 4,
+ // DEPTH_TO_SPACE = 5,
+ // DEQUANTIZE = 6,
+ EMBEDDING_LOOKUP = 7,
+ // FLOOR = 8,
+ FULLY_CONNECTED = 9,
+ HASHTABLE_LOOKUP = 10,
+ L2_NORMALIZATION = 11,
+ L2_POOL_2D = 12,
+ LOCAL_RESPONSE_NORMALIZATION = 13,
+ LOGISTIC = 14,
+ LSH_PROJECTION = 15,
+ LSTM = 16,
+ MAX_POOL_2D = 17,
+ // MUL = 18,
+ RELU = 19,
+ // RELU1=20,
+ RELU6 = 21,
+ RESHAPE = 22,
+ RESIZE_BILINEAR = 23,
+ RNN = 24,
+ SOFTMAX = 25,
+ SPACE_TO_DEPTH = 26,
+ SVDF = 27,
+ TANH = 28,
+ // TODO(aselle): Consider rename to CONCATENATE_EMBEDDINGS
+ CONCAT_EMBEDDINGS = 29,
+ SKIP_GRAM = 30,
+ CALL = 31,
+ CUSTOM = 32,
+
+}
+
+// Options for the builtin operators.
+union BuiltinOptions {
+ Conv2DOptions,
+ DepthwiseConv2DOptions,
+ ConcatEmbeddingsOptions,
+ LSHProjectionOptions,
+ Pool2DOptions,
+ SVDFOptions,
+ RNNOptions,
+ FullyConnectedOptions,
+ SoftmaxOptions,
+ ConcatenationOptions,
+ AddOptions,
+ L2NormOptions,
+ LocalResponseNormalizationOptions,
+ LSTMOptions,
+ ResizeBilinearOptions,
+ CallOptions,
+ ReshapeOptions,
+ SkipGramOptions,
+ SpaceToDepthOptions,
+}
+
+enum Padding : byte { SAME, VALID }
+
+enum ActivationFunctionType : byte {
+ NONE = 0,
+ RELU = 1,
+ RELU1 = 2,
+ RELU6 = 3,
+ TANH = 4,
+ SIGN_BIT = 5,
+}
+
+table Conv2DOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table Pool2DOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ filter_width:int;
+ filter_height:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table DepthwiseConv2DOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ depth_multiplier:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table ConcatEmbeddingsOptions {
+ num_channels:int;
+ num_columns_per_channel:[int];
+ embedding_dim_per_channel:[int]; // This could be inferred from parameters.
+}
+
+enum LSHProjectionType: byte {
+ UNKNOWN = 0,
+ SPARSE = 1,
+ DENSE = 2,
+}
+
+table LSHProjectionOptions {
+ type: LSHProjectionType;
+}
+
+table SVDFOptions {
+ rank:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+// An implementation of TensorFlow RNNCell.
+table RNNOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+// An implementation of TensorFlow fully_connected (a.k.a Dense) layer.
+table FullyConnectedOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table SoftmaxOptions {
+ beta: float;
+}
+
+// An implementation of TensorFlow concat.
+table ConcatenationOptions {
+ axis:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table AddOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table L2NormOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table LocalResponseNormalizationOptions {
+ radius:int;
+ bias:float;
+ alpha:float;
+ beta:float;
+}
+
+// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell
+table LSTMOptions {
+ fused_activation_function:ActivationFunctionType;
+ cell_clip: float; // Optional, 0.0 means no clipping
+ proj_clip: float; // Optional, 0.0 means no clipping
+}
+
+table ResizeBilinearOptions {
+ new_height:int;
+ new_width:int;
+}
+
+// A call operation options
+table CallOptions {
+ // The subgraph index that needs to be called.
+ subgraph:uint;
+}
+
+table ReshapeOptions {
+ new_shape:[int];
+}
+
+table SkipGramOptions {
+ ngram_size: int;
+ max_skip_size: int;
+ include_all_ngrams: bool;
+}
+
+table SpaceToDepthOptions {
+ block_size: int;
+}
+
+// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
+// builtin, or a string if the operator is custom.
+table OperatorCode {
+ builtin_code:BuiltinOperator;
+ custom_code:string;
+}
+
+// An operator takes tensors as inputs and outputs. The type of operation being
+// performed is determined by an index into the list of valid OperatorCodes,
+// while the specifics of each operations is configured using builtin_options
+// or custom_options.
+table Operator {
+ // Index into the operator_codes array. Using an integer here avoids
+ // complicate map lookups.
+ opcode_index:uint;
+
+ inputs:[int];
+ outputs:[int];
+
+ builtin_options:BuiltinOptions;
+ custom_options:[ubyte];
+}
+
+// The root type, defining a model.
+table SubGraph {
+ // A list of all tensors used in this model.
+ tensors:[Tensor];
+
+ // Indices of the input tensors.
+ inputs:[int];
+
+ // Indices of the output tensors.
+ outputs:[int];
+
+ // All operators, in execution order.
+ operators:[Operator];
+
+ // Name of subgraph (used for debugging).
+ name:string;
+}
+
+// Table of raw data buffers (used for constant tensors). Referenced by tensors
+// by index.
+table Buffer {
+ data:[ubyte];
+}
+
+table Model {
+ // Version of the schema.
+ version:uint;
+
+ // A list of all operator codes used in this model. This is
+ // kept in order because operators carry an index into this
+ // vector.
+ operator_codes:[OperatorCode];
+
+ // All the subgraphs of the model. The 0th is assumed to be the main
+ // model.
+ subgraphs:[SubGraph];
+
+ // A description of the model.
+ description:string;
+
+ // Buffers of the model.
+ // NOTE: It is required that the first entry in here is always an empty
+ // buffer. This is so that the default buffer index of zero in Tensor
+ // will always refer to a valid empty buffer.
+ buffers:[Buffer];
+
+}
+
+root_type Model;
diff --git a/tensorflow/contrib/lite/schema/upgrade_schema.py b/tensorflow/contrib/lite/schema/upgrade_schema.py
new file mode 100644
index 0000000000..320c7138d2
--- /dev/null
+++ b/tensorflow/contrib/lite/schema/upgrade_schema.py
@@ -0,0 +1,341 @@
+# ==============================================================================
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Upgrade script to move from pre-release schema to new schema.
+
+Usage examples:
+
+bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.json out.json
+bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.bin out.bin
+bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.bin out.json
+bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.json out.bin
+bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.tflite out.tflite
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import contextlib
+import json
+import os
+import shutil
+import subprocess
+import sys
+import tempfile
+
+import tensorflow as tf
+from tensorflow.python.platform import resource_loader
+
+parser = argparse.ArgumentParser(
+ description="Script to move TFLite models from pre-release schema to"
+ " new schema.")
+parser.add_argument(
+ "input",
+ type=str,
+ help="Input TensorFlow lite file in `.json`, `.bin` or `.tflite` format.")
+parser.add_argument(
+ "output",
+ type=str,
+ help="Output json or bin TensorFlow lite model compliant with"
+ "the new schema. Extension must be `.json`, `.bin` or `.tflite`.")
+
+
+# RAII Temporary Directory, because flatc doesn't allow direct use of tempfiles.
+@contextlib.contextmanager
+def TemporaryDirectoryResource():
+ temporary = tempfile.mkdtemp()
+ try:
+ yield temporary
+ finally:
+ shutil.rmtree(temporary)
+
+
+class Converter(object):
+ """Converts TensorFlow flatbuffer models from old to new version of schema.
+
+ This can convert between any version to the latest version. It uses
+ an incremental upgrade strategy to go from version to version.
+
+ Usage:
+ converter = Converter()
+ converter.Convert("a.tflite", "a.json")
+ converter.Convert("b.json", "b.tflite")
+ """
+
+ def __init__(self):
+ # TODO(aselle): make this work in the open source version with better
+ # path.
+ self._flatc_path = resource_loader.get_path_to_datafile(
+ "../../../../flatbuffers/flatc")
+
+ def FindSchema(base_name):
+ return resource_loader.get_path_to_datafile("%s" % base_name)
+
+ # Supported schemas for upgrade.
+ self._schemas = [
+ (0, FindSchema("schema_v0.fbs"), True, self._Upgrade0To1),
+ (1, FindSchema("schema_v1.fbs"), True, self._Upgrade1To2),
+ (2, FindSchema("schema_v2.fbs"), True, self._Upgrade2To3),
+ (3, FindSchema("schema_v3.fbs"), False, None) # Non-callable by design.
+ ]
+ # Ensure schemas are sorted, and extract latest version and upgrade
+ # dispatch function table.
+ self._schemas.sort()
+ self._new_version, self._new_schema = self._schemas[-1][:2]
+ self._upgrade_dispatch = dict(
+ (version, dispatch)
+ for version, unused1, unused2, dispatch in self._schemas)
+
+ def _Read(self, input_file, schema, raw_binary=False):
+ """Read a tflite model assuming the given flatbuffer schema.
+
+ If `input_file` is in bin, then we must use flatc to convert the schema
+ from binary to json.
+
+ Args:
+ input_file: a binary (flatbuffer) or json file to read from. Extension
+ must be `.tflite`, `.bin`, or `.json` for FlatBuffer Binary or
+ FlatBuffer JSON.
+ schema: which schema to use for reading
+ raw_binary: whether to assume raw_binary (versions previous to v3)
+ that lacked file_identifier require this.
+
+ Raises:
+ RuntimeError: When flatc cannot be invoked.
+ ValueError: When the extension is not json or bin.
+
+ Returns:
+ A dictionary representing the read tflite model.
+ """
+ raw_binary = ["--raw-binary"] if raw_binary else []
+ with TemporaryDirectoryResource() as tempdir:
+ basename = os.path.basename(input_file)
+ basename_no_extension, extension = os.path.splitext(basename)
+ if extension in [".bin", ".tflite"]:
+ # Convert to json using flatc
+ returncode = subprocess.call([
+ self._flatc_path,
+ "-t",
+ "--strict-json",
+ "--defaults-json",
+ ] + raw_binary + ["-o", tempdir, schema, "--", input_file])
+ if returncode != 0:
+ raise RuntimeError("flatc failed to convert from binary to json.")
+ json_file = os.path.join(tempdir, basename_no_extension + ".json")
+ if not os.path.exists(json_file):
+ raise RuntimeError("Could not find %r" % json_file)
+ elif extension == ".json":
+ json_file = input_file
+ else:
+ raise ValueError("Invalid extension on input file %r" % input_file)
+ return json.load(open(json_file))
+
+ def _Write(self, data, output_file):
+ """Output a json or bin version of the flatbuffer model.
+
+ Args:
+ data: Dict representing the TensorFlow Lite model to write.
+ output_file: filename to write the converted flatbuffer to. (json,
+ tflite, or bin extension is required).
+ Raises:
+ ValueError: When the extension is not json or bin
+ RuntimeError: When flatc fails to convert json data to binary.
+ """
+ _, extension = os.path.splitext(output_file)
+ with TemporaryDirectoryResource() as tempdir:
+ if extension == ".json":
+ json.dump(data, open(output_file, "w"), sort_keys=True, indent=2)
+ elif extension in [".tflite", ".bin"]:
+ input_json = os.path.join(tempdir, "temp.json")
+ with open(input_json, "w") as fp:
+ json.dump(data, fp, sort_keys=True, indent=2)
+ returncode = subprocess.call([
+ self._flatc_path, "-b", "--defaults-json", "--strict-json", "-o",
+ tempdir, self._new_schema, input_json
+ ])
+ if returncode != 0:
+ raise RuntimeError("flatc failed to convert upgraded json to binary.")
+
+ shutil.copy(os.path.join(tempdir, "temp.tflite"), output_file)
+ else:
+ raise ValueError("Invalid extension on output file %r" % output_file)
+
+ def _Upgrade0To1(self, data):
+ """Upgrade data from Version 0 to Version 1.
+
+ Changes: Added subgraphs (which contains a subset of formally global
+ entries).
+
+ Args:
+ data: Dictionary representing the TensorFlow lite data to be upgraded.
+ This will be modified in-place to be an upgraded version.
+ """
+ subgraph = {}
+ for key_to_promote in ["tensors", "operators", "inputs", "outputs"]:
+ subgraph[key_to_promote] = data[key_to_promote]
+ del data[key_to_promote]
+ data["subgraphs"] = [subgraph]
+
+ def _Upgrade1To2(self, data):
+ """Upgrade data from Version 1 to Version 2.
+
+ Changes: Rename operators to Conform to NN API.
+
+ Args:
+ data: Dictionary representing the TensorFlow lite data to be upgraded.
+ This will be modified in-place to be an upgraded version.
+ Raises:
+ ValueError: Throws when model builtins are numeric rather than symbols.
+ """
+
+ def RemapOperator(opcode_name):
+ """Go from old schema op name to new schema op name.
+
+ Args:
+ opcode_name: String representing the ops (see :schema.fbs).
+ Returns:
+ Converted opcode_name from V1 to V2.
+ """
+ old_name_to_new_name = {
+ "CONVOLUTION": "CONV_2D",
+ "DEPTHWISE_CONVOLUTION": "DEPTHWISE_CONV_2D",
+ "AVERAGE_POOL": "AVERAGE_POOL_2D",
+ "MAX_POOL": "MAX_POOL_2D",
+ "L2_POOL": "L2_POOL_2D",
+ "SIGMOID": "LOGISTIC",
+ "L2NORM": "L2_NORMALIZATION",
+ "LOCAL_RESPONSE_NORM": "LOCAL_RESPONSE_NORMALIZATION",
+ "Basic_RNN": "RNN",
+ }
+
+ return (old_name_to_new_name[opcode_name]
+ if opcode_name in old_name_to_new_name else opcode_name)
+
+ def RemapOperatorType(operator_type):
+ """Remap operator structs from old names to new names.
+
+ Args:
+ operator_type: String representing the builtin operator data type
+ string.
+ (see :schema.fbs).
+ Returns:
+ Upgraded builtin operator data type as a string.
+ """
+ old_to_new = {
+ "PoolOptions": "Pool2DOptions",
+ "DepthwiseConvolutionOptions": "DepthwiseConv2DOptions",
+ "ConvolutionOptions": "Conv2DOptions",
+ "LocalResponseNormOptions": "LocalResponseNormalizationOptions",
+ "BasicRNNOptions": "RNNOptions",
+ }
+ return (old_to_new[operator_type]
+ if operator_type in old_to_new else operator_type)
+
+ for subgraph in data["subgraphs"]:
+ for ops in subgraph["operators"]:
+ ops["builtin_options_type"] = RemapOperatorType(
+ ops["builtin_options_type"])
+
+ # Upgrade the operator codes
+ for operator_code in data["operator_codes"]:
+ if not isinstance(operator_code["builtin_code"], unicode):
+ raise ValueError("builtin_code %r is non-string. this usually means"
+ "your model has consistency problems." %
+ (operator_code["builtin_code"]))
+ operator_code["builtin_code"] = (RemapOperator(
+ operator_code["builtin_code"]))
+
+ def _Upgrade2To3(self, data):
+ """Upgrade data from Version 2 to Version 3.
+
+ Changed actual read-only tensor data to be in a buffers table instead
+ of inline with the tensor.
+
+ Args:
+ data: Dictionary representing the TensorFlow lite data to be upgraded.
+ This will be modified in-place to be an upgraded version.
+ """
+ buffers = [{"data": []}] # Start with 1 empty buffer
+ for subgraph in data["subgraphs"]:
+ if "tensors" not in subgraph:
+ continue
+ for tensor in subgraph["tensors"]:
+ if "data_buffer" not in tensor:
+ tensor["buffer"] = 0
+ else:
+ if tensor["data_buffer"]:
+ tensor[u"buffer"] = len(buffers)
+ buffers.append({"data": tensor["data_buffer"]})
+ else:
+ tensor["buffer"] = 0
+ del tensor["data_buffer"]
+ data["buffers"] = buffers
+
+ def _PerformUpgrade(self, data):
+ """Manipulate the `data` (parsed JSON) based on changes in format.
+
+ This incrementally will upgrade from version to version within data.
+
+ Args:
+ data: Dictionary representing the TensorFlow data. This will be upgraded
+ in place.
+ """
+ while data["version"] < self._new_version:
+ self._upgrade_dispatch[data["version"]](data)
+ data["version"] += 1
+
+ def Convert(self, input_file, output_file):
+ """Perform schema conversion from input_file to output_file.
+
+ Args:
+ input_file: Filename of TensorFlow Lite data to convert from. Must
+ be `.json` or `.bin` extension files for JSON or Binary forms of
+ the TensorFlow FlatBuffer schema.
+ output_file: Filename to write to. Extension also must be `.json`
+ or `.bin`.
+
+ Raises:
+ RuntimeError: Generated when none of the upgrader supported schemas
+ matche the `input_file` data.
+ """
+ # Read data in each schema (since they are incompatible). Version is
+ # always present. Use the read data that matches the version of the
+ # schema.
+ for version, schema, raw_binary, _ in self._schemas:
+ try:
+ data_candidate = self._Read(input_file, schema, raw_binary)
+ except RuntimeError:
+ continue # Skip and hope another schema works
+ if "version" not in data_candidate: # Assume version 1 if not present.
+ data_candidate["version"] = 1
+ elif data_candidate["version"] == 0: # Version 0 doesn't exist in wild.
+ data_candidate["version"] = 1
+
+ if data_candidate["version"] == version:
+ self._PerformUpgrade(data_candidate)
+ self._Write(data_candidate, output_file)
+ return
+ raise RuntimeError("No schema that the converter understands worked with "
+ "the data file you provided.")
+
+
+def main(argv):
+ del argv
+ Converter().Convert(FLAGS.input, FLAGS.output)
+
+
+if __name__ == "__main__":
+ FLAGS, unparsed = parser.parse_known_args()
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/contrib/lite/schema/upgrade_schema_test.py b/tensorflow/contrib/lite/schema/upgrade_schema_test.py
new file mode 100644
index 0000000000..475cdb9d8b
--- /dev/null
+++ b/tensorflow/contrib/lite/schema/upgrade_schema_test.py
@@ -0,0 +1,317 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Testing for updating TensorFlow lite schema."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import tempfile
+from tensorflow.contrib.lite.schema import upgrade_schema as upgrade_schema_lib
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test as test_lib
+
+EMPTY_TEST_SCHEMA_V1 = {
+ "version": 1,
+ "operator_codes": [],
+ "subgraphs": [],
+}
+
+EMPTY_TEST_SCHEMA_V3 = {
+ "version": 3,
+ "operator_codes": [],
+ "subgraphs": [],
+ "buffers": [{
+ "data": []
+ }]
+}
+
+TEST_SCHEMA_V0 = {
+ "operator_codes": [],
+ "tensors": [],
+ "inputs": [],
+ "outputs": [],
+ "operators": [],
+ "version": 0
+}
+
+TEST_SCHEMA_V3 = {
+ "operator_codes": [],
+ "buffers": [{
+ "data": []
+ }],
+ "subgraphs": [{
+ "tensors": [],
+ "inputs": [],
+ "outputs": [],
+ "operators": [],
+ }],
+ "version":
+ 3
+}
+
+FULL_TEST_SCHEMA_V1 = {
+ "version":
+ 1,
+ "operator_codes": [
+ {
+ "builtin_code": "CONVOLUTION"
+ },
+ {
+ "builtin_code": "DEPTHWISE_CONVOLUTION"
+ },
+ {
+ "builtin_code": "AVERAGE_POOL"
+ },
+ {
+ "builtin_code": "MAX_POOL"
+ },
+ {
+ "builtin_code": "L2_POOL"
+ },
+ {
+ "builtin_code": "SIGMOID"
+ },
+ {
+ "builtin_code": "L2NORM"
+ },
+ {
+ "builtin_code": "LOCAL_RESPONSE_NORM"
+ },
+ {
+ "builtin_code": "ADD"
+ },
+ {
+ "builtin_code": "Basic_RNN"
+ },
+ ],
+ "subgraphs": [{
+ "operators": [
+ {
+ "builtin_options_type": "PoolOptions"
+ },
+ {
+ "builtin_options_type": "DepthwiseConvolutionOptions"
+ },
+ {
+ "builtin_options_type": "ConvolutionOptions"
+ },
+ {
+ "builtin_options_type": "LocalResponseNormOptions"
+ },
+ {
+ "builtin_options_type": "BasicRNNOptions"
+ },
+ ],
+ }],
+ "description":
+ "",
+}
+
+FULL_TEST_SCHEMA_V3 = {
+ "version":
+ 3,
+ "operator_codes": [
+ {
+ "builtin_code": "CONV_2D"
+ },
+ {
+ "builtin_code": "DEPTHWISE_CONV_2D"
+ },
+ {
+ "builtin_code": "AVERAGE_POOL_2D"
+ },
+ {
+ "builtin_code": "MAX_POOL_2D"
+ },
+ {
+ "builtin_code": "L2_POOL_2D"
+ },
+ {
+ "builtin_code": "LOGISTIC"
+ },
+ {
+ "builtin_code": "L2_NORMALIZATION"
+ },
+ {
+ "builtin_code": "LOCAL_RESPONSE_NORMALIZATION"
+ },
+ {
+ "builtin_code": "ADD"
+ },
+ {
+ "builtin_code": "RNN"
+ },
+ ],
+ "subgraphs": [{
+ "operators": [
+ {
+ "builtin_options_type": "Pool2DOptions"
+ },
+ {
+ "builtin_options_type": "DepthwiseConv2DOptions"
+ },
+ {
+ "builtin_options_type": "Conv2DOptions"
+ },
+ {
+ "builtin_options_type": "LocalResponseNormalizationOptions"
+ },
+ {
+ "builtin_options_type": "RNNOptions"
+ },
+ ],
+ }],
+ "description":
+ "",
+ "buffers": [{
+ "data": []
+ }]
+}
+
+BUFFER_TEST_V2 = {
+ "operator_codes": [],
+ "buffers": [],
+ "subgraphs": [{
+ "tensors": [
+ {
+ "data_buffer": [1, 2, 3, 4]
+ },
+ {
+ "data_buffer": [1, 2, 3, 4, 5, 6, 7, 8]
+ },
+ {
+ "data_buffer": []
+ },
+ ],
+ "inputs": [],
+ "outputs": [],
+ "operators": [],
+ }],
+ "version":
+ 2
+}
+
+BUFFER_TEST_V3 = {
+ "operator_codes": [],
+ "subgraphs": [{
+ "tensors": [
+ {
+ "buffer": 1
+ },
+ {
+ "buffer": 2
+ },
+ {
+ "buffer": 0
+ },
+ ],
+ "inputs": [],
+ "outputs": [],
+ "operators": [],
+ }],
+ "buffers": [
+ {
+ "data": []
+ },
+ {
+ "data": [1, 2, 3, 4]
+ },
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8]
+ },
+ ],
+ "version":
+ 3
+}
+
+
+def JsonDumpAndFlush(data, fp):
+ """Write the dictionary `data` to a JSON file `fp` (and flush).
+
+ Args:
+ data: in a dictionary that is JSON serializable.
+ fp: File-like object
+ """
+ json.dump(data, fp)
+ fp.flush()
+
+
+class TestSchemaUpgrade(test_util.TensorFlowTestCase):
+
+ def testNonExistantFile(self):
+ converter = upgrade_schema_lib.Converter()
+ non_existent = tempfile.mktemp(suffix=".json")
+ with self.assertRaisesRegexp(IOError, "No such file or directory"):
+ converter.Convert(non_existent, non_existent)
+
+ def testInvalidExtension(self):
+ converter = upgrade_schema_lib.Converter()
+ invalid_extension = tempfile.mktemp(suffix=".foo")
+ with self.assertRaisesRegexp(ValueError, "Invalid extension on input"):
+ converter.Convert(invalid_extension, invalid_extension)
+ with tempfile.NamedTemporaryFile(suffix=".json") as in_json:
+ JsonDumpAndFlush(EMPTY_TEST_SCHEMA_V1, in_json)
+ with self.assertRaisesRegexp(ValueError, "Invalid extension on output"):
+ converter.Convert(in_json.name, invalid_extension)
+
+ def CheckConversion(self, data_old, data_expected):
+ """Given a data dictionary, test upgrading to current version.
+
+ Args:
+ data_old: TFLite model as a dictionary (arbitrary version).
+ data_expected: TFLite model as a dictionary (upgraded).
+ """
+ converter = upgrade_schema_lib.Converter()
+ with tempfile.NamedTemporaryFile(suffix=".json") as in_json, \
+ tempfile.NamedTemporaryFile(suffix=".json") as out_json, \
+ tempfile.NamedTemporaryFile(suffix=".bin") as out_bin, \
+ tempfile.NamedTemporaryFile(suffix=".tflite") as out_tflite:
+ JsonDumpAndFlush(data_old, in_json)
+ # Test JSON output
+ converter.Convert(in_json.name, out_json.name)
+ # Test binary output
+ # Convert to .tflite and then to .bin and check if binary is equal
+ converter.Convert(in_json.name, out_tflite.name)
+ converter.Convert(out_tflite.name, out_bin.name)
+ self.assertEqual(open(out_bin.name).read(), open(out_tflite.name).read())
+ # Test that conversion actually produced successful new json.
+ converted_schema = json.load(out_json)
+ self.assertEqual(converted_schema, data_expected)
+
+ def testAlreadyUpgraded(self):
+ """A file already at version 3 should stay at version 3."""
+ self.CheckConversion(EMPTY_TEST_SCHEMA_V3, EMPTY_TEST_SCHEMA_V3)
+ self.CheckConversion(TEST_SCHEMA_V3, TEST_SCHEMA_V3)
+ self.CheckConversion(BUFFER_TEST_V3, BUFFER_TEST_V3)
+
+ # Disable this while we have incorrectly versioned structures around.
+ # def testV0Upgrade_IntroducesSubgraphs(self):
+ # """V0 did not have subgraphs; check to make sure they get introduced."""
+ # self.CheckConversion(TEST_SCHEMA_V0, TEST_SCHEMA_V3)
+
+ def testV1Upgrade_RenameOps(self):
+ """V1 had many different names for ops; check to make sure they rename."""
+ self.CheckConversion(EMPTY_TEST_SCHEMA_V1, EMPTY_TEST_SCHEMA_V3)
+ self.CheckConversion(FULL_TEST_SCHEMA_V1, FULL_TEST_SCHEMA_V3)
+
+ def testV2Upgrade_CreateBuffers(self):
+ """V2 did not have buffers; check to make sure they are created."""
+ self.CheckConversion(BUFFER_TEST_V2, BUFFER_TEST_V3)
+
+
+if __name__ == "__main__":
+ test_lib.main()
diff --git a/tensorflow/contrib/lite/simple_memory_arena.cc b/tensorflow/contrib/lite/simple_memory_arena.cc
new file mode 100644
index 0000000000..4aab244989
--- /dev/null
+++ b/tensorflow/contrib/lite/simple_memory_arena.cc
@@ -0,0 +1,136 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/simple_memory_arena.h"
+
+#include <cstring>
+#include <limits>
+#include <vector>
+
+namespace {
+
+template <typename T>
+T AlignTo(size_t alignment, T offset) {
+ return offset % alignment == 0 ? offset
+ : offset + (alignment - offset % alignment);
+}
+
+} // namespace
+
+namespace tflite {
+
+TfLiteStatus SimpleMemoryArena::Allocate(TfLiteContext* context,
+ size_t alignment, size_t size,
+ ArenaAlloc* new_alloc) {
+ TF_LITE_ENSURE(context, alignment < arena_alignment_);
+
+ size_t current_top = 0;
+
+ if (!allocs_.empty()) {
+ auto last = allocs_.rbegin();
+ current_top = last->offset + last->size;
+ }
+
+ // If we don't find a better gap just allocate at the end of the buffer.
+ size_t best_offset = AlignTo(alignment, current_top);
+ size_t best_offset_fit = std::numeric_limits<size_t>::max();
+ auto best_insertion_it = allocs_.end();
+
+ // Go through the sorted allocs and look at the gaps between them.
+ size_t current_offset = 0;
+ for (auto it = allocs_.begin(); it != allocs_.end(); ++it) {
+ size_t aligned_current_offset = AlignTo(alignment, current_offset);
+ // If we found a gap larger than required size, and smaller than previous
+ // best fit, take it.
+ if (aligned_current_offset + size <= it->offset &&
+ it->offset - current_offset < best_offset_fit) {
+ best_offset = aligned_current_offset;
+ best_offset_fit = it->offset - current_offset;
+ best_insertion_it = it;
+ }
+ current_offset = it->offset + it->size;
+ }
+
+ // Update the required buffer size.
+ high_water_mark_ = std::max(high_water_mark_, best_offset + size);
+
+ new_alloc->offset = best_offset;
+ new_alloc->size = size;
+ allocs_.insert(best_insertion_it, *new_alloc);
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus SimpleMemoryArena::Deallocate(TfLiteContext* context,
+ const ArenaAlloc& alloc) {
+ int erased_allocs_count = 0;
+ auto it = allocs_.begin();
+ while (it != allocs_.end()) {
+ if (it->offset == alloc.offset) {
+ TF_LITE_ENSURE_EQ(context, it->size, alloc.size);
+ erased_allocs_count++;
+ it = allocs_.erase(it);
+ } else {
+ ++it;
+ }
+ }
+ TF_LITE_ENSURE_EQ(context, erased_allocs_count, 1);
+ return kTfLiteOk;
+}
+
+TfLiteStatus SimpleMemoryArena::Commit(TfLiteContext* context) {
+ size_t required_size = RequiredBufferSize();
+ if (required_size > underlying_buffer_size_) {
+ char* new_alloc = new char[required_size];
+ char* new_underlying_buffer_aligned_ptr = reinterpret_cast<char*>(
+ AlignTo(arena_alignment_, reinterpret_cast<intptr_t>(new_alloc)));
+
+ // If the arena had been previously allocated, copy over the old memory.
+ // Since Alloc pointers are offset based, they will remain valid in the new
+ // memory block.
+ if (high_water_mark_ > 0 && underlying_buffer_size_ > 0) {
+ size_t copy_amount = std::min(
+ underlying_buffer_.get() + underlying_buffer_size_ -
+ underlying_buffer_aligned_ptr_,
+ new_alloc + required_size - new_underlying_buffer_aligned_ptr);
+ memcpy(new_underlying_buffer_aligned_ptr, underlying_buffer_aligned_ptr_,
+ copy_amount);
+ }
+
+ underlying_buffer_.reset(new_alloc);
+ underlying_buffer_size_ = required_size;
+ underlying_buffer_aligned_ptr_ = new_underlying_buffer_aligned_ptr;
+ }
+ commited_ = true;
+ return underlying_buffer_ != nullptr ? kTfLiteOk : kTfLiteError;
+}
+
+TfLiteStatus SimpleMemoryArena::ResolveAlloc(TfLiteContext* context,
+ const ArenaAlloc& alloc,
+ char** output_ptr) {
+ TF_LITE_ENSURE(context, commited_);
+ TF_LITE_ENSURE(context, output_ptr != nullptr);
+ *output_ptr = underlying_buffer_aligned_ptr_ + alloc.offset;
+ return kTfLiteOk;
+}
+
+TfLiteStatus SimpleMemoryArena::Clear() {
+ commited_ = false;
+ high_water_mark_ = 0;
+ allocs_.clear();
+ return kTfLiteOk;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/simple_memory_arena.h b/tensorflow/contrib/lite/simple_memory_arena.h
new file mode 100644
index 0000000000..0d0b7f9ff7
--- /dev/null
+++ b/tensorflow/contrib/lite/simple_memory_arena.h
@@ -0,0 +1,84 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_
+
+#include <list>
+#include <memory>
+#include "tensorflow/contrib/lite/context.h"
+
+namespace tflite {
+
+// This little structure holds the offset and the size for a dynamic memory
+// allocation in the memory arena. When the arena is commited and the
+// underlying buffer is set, the alloc can be resolved into an actual memory
+// pointer.
+struct ArenaAlloc {
+ ArenaAlloc() : offset(0), size(0) {}
+
+ size_t offset;
+ size_t size;
+
+ inline bool operator<(const ArenaAlloc& other) const {
+ return offset < other.offset;
+ }
+};
+
+// This small class is responsible for allocating, dealocating and reusing
+// dynamic memory from a common underlying buffer. The arena can be used in
+// scenarios when the pattern of memory allocations and dealocations is
+// repetitive, e.g. running NN inference in multiple iterations.
+class SimpleMemoryArena {
+ public:
+ explicit SimpleMemoryArena(size_t arena_alignment)
+ : commited_(false),
+ arena_alignment_(arena_alignment),
+ high_water_mark_(0),
+ underlying_buffer_size_(0),
+ allocs_() {}
+
+ TfLiteStatus Allocate(TfLiteContext* context, size_t alignment, size_t size,
+ ArenaAlloc* new_alloc);
+
+ TfLiteStatus Deallocate(TfLiteContext* context, const ArenaAlloc& alloc);
+
+ inline size_t RequiredBufferSize() {
+ // Add in a small amount of padding to reduce the chance of resize events
+ // for small allocations.
+ size_t padding = arena_alignment_;
+ return arena_alignment_ + high_water_mark_ + padding;
+ }
+
+ TfLiteStatus Commit(TfLiteContext* context);
+
+ TfLiteStatus ResolveAlloc(TfLiteContext* context, const ArenaAlloc& alloc,
+ char** output_ptr);
+
+ TfLiteStatus Clear();
+
+ private:
+ bool commited_;
+ size_t arena_alignment_;
+ size_t high_water_mark_;
+ std::unique_ptr<char[]> underlying_buffer_;
+ size_t underlying_buffer_size_;
+ char* underlying_buffer_aligned_ptr_;
+ // TODO(maciekc): add list iterator to the ArenaAlloc to lookup quickly.
+ std::list<ArenaAlloc> allocs_;
+};
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_
diff --git a/tensorflow/contrib/lite/simple_memory_arena_test.cc b/tensorflow/contrib/lite/simple_memory_arena_test.cc
new file mode 100644
index 0000000000..ac676092c6
--- /dev/null
+++ b/tensorflow/contrib/lite/simple_memory_arena_test.cc
@@ -0,0 +1,91 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/simple_memory_arena.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace {
+
+TEST(SimpleMemoryArenaTest, BasicArenaOperations) {
+ TfLiteContext context;
+ SimpleMemoryArena arena(64);
+ ArenaAlloc allocs[6];
+
+ arena.Allocate(&context, 32, 2047, &allocs[0]);
+ arena.Allocate(&context, 32, 2047, &allocs[1]);
+ arena.Allocate(&context, 32, 2047, &allocs[2]);
+ arena.Deallocate(&context, allocs[0]);
+ arena.Allocate(&context, 32, 1023, &allocs[3]);
+ arena.Allocate(&context, 32, 2047, &allocs[4]);
+ arena.Deallocate(&context, allocs[1]);
+ arena.Allocate(&context, 32, 1023, &allocs[5]);
+
+ EXPECT_EQ(allocs[0].offset, 0);
+ EXPECT_EQ(allocs[1].offset, 2048);
+ EXPECT_EQ(allocs[2].offset, 4096);
+ EXPECT_EQ(allocs[3].offset, 0);
+ EXPECT_EQ(allocs[4].offset, 6144);
+ EXPECT_EQ(allocs[5].offset, 1024);
+}
+
+TEST(SimpleMemoryArenaTest, TestAfterClear) {
+ TfLiteContext context;
+ SimpleMemoryArena arena(64);
+ ArenaAlloc allocs[9];
+
+ arena.Allocate(&context, 32, 2047, &allocs[0]);
+ arena.Allocate(&context, 32, 2047, &allocs[1]);
+ arena.Allocate(&context, 32, 2047, &allocs[2]);
+ arena.Commit(&context);
+
+ EXPECT_EQ(allocs[0].offset, 0);
+ EXPECT_EQ(allocs[1].offset, 2048);
+ EXPECT_EQ(allocs[2].offset, 4096);
+
+ arena.Clear();
+
+ // Test with smaller allocs.
+ arena.Allocate(&context, 32, 1023, &allocs[3]);
+ arena.Allocate(&context, 32, 1023, &allocs[4]);
+ arena.Allocate(&context, 32, 1023, &allocs[5]);
+ arena.Commit(&context);
+
+ EXPECT_EQ(allocs[3].offset, 0);
+ EXPECT_EQ(allocs[4].offset, 1024);
+ EXPECT_EQ(allocs[5].offset, 2048);
+
+ arena.Clear();
+
+ // Test larger allocs which should require a reallocation.
+ arena.Allocate(&context, 32, 4095, &allocs[6]);
+ arena.Allocate(&context, 32, 4095, &allocs[7]);
+ arena.Allocate(&context, 32, 4095, &allocs[8]);
+ arena.Commit(&context);
+
+ EXPECT_EQ(allocs[6].offset, 0);
+ EXPECT_EQ(allocs[7].offset, 4096);
+ EXPECT_EQ(allocs[8].offset, 8192);
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/string.h b/tensorflow/contrib/lite/string.h
new file mode 100644
index 0000000000..ecd6f04ec2
--- /dev/null
+++ b/tensorflow/contrib/lite/string.h
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Abstract string. We don't want even absl at this level.
+#ifndef _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_H_
+#define _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_H_
+
+#include <string>
+#include "tensorflow/core/platform/platform.h"
+
+namespace tflite {
+
+#ifndef PLATFORM_GOOGLE
+using std::string;
+#endif
+
+} // namespace tflite
+
+#endif // _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_H_
diff --git a/tensorflow/contrib/lite/string_util.cc b/tensorflow/contrib/lite/string_util.cc
new file mode 100644
index 0000000000..cd41299d38
--- /dev/null
+++ b/tensorflow/contrib/lite/string_util.cc
@@ -0,0 +1,117 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/string_util.h"
+
+#include <string.h>
+#include <vector>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+
+namespace tflite {
+namespace {
+
+// Convenient method to get pointer to int32_t.
+int32_t* GetIntPtr(char* ptr) { return reinterpret_cast<int32_t*>(ptr); }
+} // namespace
+
+void DynamicBuffer::AddString(const char* str, size_t len) {
+ data_.resize(data_.size() + len);
+ memcpy(data_.data() + offset_.back(), str, len);
+ offset_.push_back(offset_.back() + len);
+}
+
+void DynamicBuffer::AddString(const StringRef& string) {
+ AddString(string.str, string.len);
+}
+
+void DynamicBuffer::AddJoinedString(const std::vector<StringRef>& strings,
+ char separator) {
+ // Resize the data buffer.
+ int total_len = strings.size() - 1;
+ for (StringRef ref : strings) {
+ total_len += ref.len;
+ }
+ data_.resize(data_.size() + total_len);
+
+ int current_idx = 0;
+ for (StringRef ref : strings) {
+ char* dst = data_.data() + offset_.back() + current_idx;
+
+ // Fill separator if not first string.
+ if (current_idx != 0) {
+ *dst = separator;
+ ++dst;
+ ++current_idx;
+ }
+
+ // Fill content of the string.
+ memcpy(dst, ref.str, ref.len);
+ current_idx += ref.len;
+ }
+ offset_.push_back(offset_.back() + total_len);
+}
+
+void DynamicBuffer::WriteToTensor(TfLiteTensor* tensor) {
+ // Allocate sufficient memory to tensor buffer.
+ int32_t num_strings = offset_.size() - 1;
+ // Total bytes include:
+ // * size of content (data_.size)
+ // * offset of each tensor (sizeof(int32_t) * num_strings)
+ // * length of whole buffer (int32_t)
+ // * num of strings (int32_t).
+ int32_t bytes = data_.size() // size of content
+ + sizeof(int32_t) * (num_strings + 2); // size of header
+
+ // Output tensor will take over the ownership of tensor_buffer, and free it
+ // during Interpreter destruction.
+ char* tensor_buffer = static_cast<char*>(malloc(bytes));
+
+ // Set num of string
+ memcpy(tensor_buffer, &num_strings, sizeof(int32_t));
+
+ // Set offset of strings.
+ int32_t start = sizeof(int32_t) * (num_strings + 2);
+ for (int i = 0; i < offset_.size(); i++) {
+ int32_t offset = start + offset_[i];
+ memcpy(tensor_buffer + sizeof(int32_t) * (i + 1), &offset, sizeof(int32_t));
+ }
+
+ // Copy data of strings.
+ memcpy(tensor_buffer + start, data_.data(), data_.size());
+
+ // Set tensor content pointer to tensor_buffer, and release original data.
+ auto dims = TfLiteIntArrayCreate(1);
+ dims->data[0] = num_strings;
+ TfLiteTensorReset(tensor->type, tensor->name, dims, tensor->params,
+ tensor_buffer, bytes, kTfLiteDynamic, tensor->allocation,
+ tensor);
+}
+
+int GetStringCount(const TfLiteTensor* tensor) {
+ // The first integers in the raw buffer is the number of strings.
+ return *GetIntPtr(tensor->data.raw);
+}
+
+StringRef GetString(const TfLiteTensor* tensor, int string_index) {
+ int32_t* offset =
+ GetIntPtr(tensor->data.raw + sizeof(int32_t) * (string_index + 1));
+ return {
+ tensor->data.raw + (*offset),
+ (*(offset + 1)) - (*offset),
+ };
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/string_util.h b/tensorflow/contrib/lite/string_util.h
new file mode 100644
index 0000000000..12872d1123
--- /dev/null
+++ b/tensorflow/contrib/lite/string_util.h
@@ -0,0 +1,91 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Util methods to read and write String tensors.
+// String tensors are considered to be char tensor with protocol.
+// [0, 3] 4 bytes: N, num of strings in the tensor in little endian.
+// [(i+1)*4, (i+1)*4+3] 4 bytes: offset of i-th string in little endian.
+// [(N+2)*4, (N+2)*4+3] 4 bytes: length of the whole char buffer.
+// [offset(i), offset(i+1) - 1] : content of i-th string.
+// Example of a string tensor:
+// [
+// 2, 0, 0, 0, # 2 strings.
+// 16, 0, 0, 0, # 0-th string starts from index 12.
+// 18, 0, 0, 0, # 1-st string starts from index 18.
+// 18, 0, 0, 0, # total length of array.
+// 'A', 'B', # 0-th string [16..17]: "AB"
+// ] # 1-th string, empty
+//
+// A typical usage:
+// In op.Eval(context, node):
+// DynamicBuffer buf;
+// # Add string "AB" to tensor, string is stored in dynamic buffer.
+// buf.AddString("AB", 2);
+// # Write content of DynamicBuffer to tensor in format of string tensor
+// # described above.
+// buf.WriteToTensor(tensor)
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_
+
+#include <vector>
+
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/string.h"
+
+namespace tflite {
+
+// Convenient structure to store string pointer and length.
+typedef struct {
+ char* str;
+ int len;
+} StringRef;
+
+// DynamicBuffer holds temporary buffer that will be used to create a dynamic
+// tensor. A typical usage is to initialize a DynamicBuffer object, fill in
+// content and call CreateStringTensor in op.Eval().
+class DynamicBuffer {
+ public:
+ DynamicBuffer() : offset_({0}) {}
+
+ // Add string to dynamic buffer by resizing the buffer and copying the data.
+ void AddString(const StringRef& string);
+
+ // Add string to dynamic buffer by resizing the buffer and copying the data.
+ void AddString(const char* str, size_t len);
+
+ // Join a list of string with separator, and add as a single string to the
+ // buffer.
+ void AddJoinedString(const std::vector<StringRef>& strings, char separator);
+
+ // Fill content into a string tensor.
+ void WriteToTensor(TfLiteTensor* tensor);
+
+ private:
+ // Data buffer to store contents of strings, not including headers.
+ std::vector<char> data_;
+ // Offset of the starting index of each string in data buffer.
+ std::vector<int32_t> offset_;
+};
+
+// Return num of strings in a String tensor.
+int GetStringCount(const TfLiteTensor* tensor);
+
+// Get String pointer and length of index-th string in tensor.
+// NOTE: This will not create a copy of string data.
+StringRef GetString(const TfLiteTensor* tensor, int string_index);
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_
diff --git a/tensorflow/contrib/lite/string_util_test.cc b/tensorflow/contrib/lite/string_util_test.cc
new file mode 100644
index 0000000000..5c351638dc
--- /dev/null
+++ b/tensorflow/contrib/lite/string_util_test.cc
@@ -0,0 +1,117 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/string_util.h"
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+
+namespace tflite {
+
+TEST(StringUtil, TestStringUtil) {
+ Interpreter interpreter;
+ interpreter.AddTensors(3);
+
+ TfLiteTensor* t0 = interpreter.tensor(0);
+ t0->type = kTfLiteString;
+ t0->allocation_type = kTfLiteDynamic;
+
+ TfLiteTensor* t1 = interpreter.tensor(1);
+ t1->type = kTfLiteString;
+ t1->allocation_type = kTfLiteDynamic;
+
+ char data[] = {1, 0, 0, 0, 12, 0, 0, 0, 15, 0, 0, 0, 'X', 'Y', 'Z'};
+
+ interpreter.SetTensorParametersReadOnly(2, kTfLiteString, "", {1}, {}, data,
+ 15);
+ TfLiteTensor* t2 = interpreter.tensor(2);
+ interpreter.AllocateTensors();
+
+ char s0[] = "ABC";
+ string s1 = "DEFG";
+ char s2[] = "";
+
+ // Write strings to tensors
+ DynamicBuffer buf0;
+ buf0.AddString(s0, 3);
+ DynamicBuffer buf1;
+ buf1.AddString(s1.data(), s1.length());
+ buf0.AddString(s2, 0);
+ buf0.WriteToTensor(t0);
+ buf1.WriteToTensor(t1);
+
+ // Read strings from tensors.
+ ASSERT_EQ(GetStringCount(t0), 2);
+ StringRef str_ref;
+ str_ref = GetString(t0, 0);
+ ASSERT_EQ(string(str_ref.str, str_ref.len), "ABC");
+ str_ref = GetString(t0, 1);
+ ASSERT_EQ(string(str_ref.str, str_ref.len), "");
+ ASSERT_EQ(t0->bytes, 19);
+
+ ASSERT_EQ(GetStringCount(t1), 1);
+ str_ref = GetString(t1, 0);
+ ASSERT_EQ(string(str_ref.str, str_ref.len), "DEFG");
+ ASSERT_EQ(t1->bytes, 16);
+
+ ASSERT_EQ(GetStringCount(t2), 1);
+ str_ref = GetString(t2, 0);
+ ASSERT_EQ(string(str_ref.str, str_ref.len), "XYZ");
+ ASSERT_EQ(t2->bytes, 15);
+}
+
+TEST(StringUtil, TestAddJoinedString) {
+ Interpreter interpreter;
+ interpreter.AddTensors(1);
+ TfLiteTensor* t0 = interpreter.tensor(0);
+ t0->type = kTfLiteString;
+ t0->allocation_type = kTfLiteDynamic;
+
+ char s0[] = "ABC";
+ char s1[] = "DEFG";
+ char s2[] = "";
+ char s3[] = "XYZ";
+
+ DynamicBuffer buf;
+ buf.AddJoinedString({{s0, 3}, {s1, 4}, {s2, 0}, {s3, 3}}, ' ');
+ buf.WriteToTensor(t0);
+
+ ASSERT_EQ(GetStringCount(t0), 1);
+ StringRef str_ref;
+ str_ref = GetString(t0, 0);
+ ASSERT_EQ(string(str_ref.str, str_ref.len), "ABC DEFG XYZ");
+ ASSERT_EQ(t0->bytes, 25);
+}
+
+TEST(StringUtil, TestEmptyList) {
+ Interpreter interpreter;
+ interpreter.AddTensors(1);
+ TfLiteTensor* t0 = interpreter.tensor(0);
+ t0->type = kTfLiteString;
+ t0->allocation_type = kTfLiteDynamic;
+ DynamicBuffer buf;
+ buf.WriteToTensor(t0);
+
+ ASSERT_EQ(GetStringCount(t0), 0);
+ ASSERT_EQ(t0->bytes, 8);
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/testdata/0_subgraphs.bin b/tensorflow/contrib/lite/testdata/0_subgraphs.bin
new file mode 100644
index 0000000000..5606898d7f
--- /dev/null
+++ b/tensorflow/contrib/lite/testdata/0_subgraphs.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/testdata/2_subgraphs.bin b/tensorflow/contrib/lite/testdata/2_subgraphs.bin
new file mode 100644
index 0000000000..07308ba62b
--- /dev/null
+++ b/tensorflow/contrib/lite/testdata/2_subgraphs.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/testdata/empty_model.bin b/tensorflow/contrib/lite/testdata/empty_model.bin
new file mode 100644
index 0000000000..1762ca3938
--- /dev/null
+++ b/tensorflow/contrib/lite/testdata/empty_model.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/testdata/multi_add.bin b/tensorflow/contrib/lite/testdata/multi_add.bin
new file mode 100644
index 0000000000..e5048a3281
--- /dev/null
+++ b/tensorflow/contrib/lite/testdata/multi_add.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/testdata/multi_add.json b/tensorflow/contrib/lite/testdata/multi_add.json
new file mode 100644
index 0000000000..97b931dba8
--- /dev/null
+++ b/tensorflow/contrib/lite/testdata/multi_add.json
@@ -0,0 +1,46 @@
+{
+ "version": 1,
+ "operator_codes": [
+ {
+ "builtin_code": "ADD"
+ }
+ ],
+ "subgraphs": [
+ {
+ "tensors": [
+ { "shape": [ 1, 8, 8, 3 ], "name": "a" },
+ { "shape": [ 1, 8, 8, 3 ], "name": "b" },
+ { "shape": [ 1, 8, 8, 3 ], "name": "c" },
+ { "shape": [ 1, 8, 8, 3 ], "name": "d" },
+ { "shape": [ 1, 8, 8, 3 ], "name": "i" },
+ { "shape": [ 1, 8, 8, 3 ], "name": "x" },
+ { "shape": [ 1, 8, 8, 3 ], "name": "y" }
+ ],
+ "inputs": [ 0, 1, 2, 3 ],
+ "outputs": [ 5, 6 ],
+ "operators": [
+ {
+ "inputs": [ 1, 2 ],
+ "outputs": [ 4 ],
+ "builtin_options_type": "AddOptions",
+ "builtin_options": {
+ }
+ },
+ {
+ "inputs": [ 0, 4 ],
+ "outputs": [ 5 ],
+ "builtin_options_type": "AddOptions",
+ "builtin_options": {
+ }
+ },
+ {
+ "inputs": [ 3, 4 ],
+ "outputs": [ 6 ],
+ "builtin_options_type": "AddOptions",
+ "builtin_options": {
+ }
+ }
+ ]
+ }
+ ]
+}
diff --git a/tensorflow/contrib/lite/testdata/no_subgraphs.bin b/tensorflow/contrib/lite/testdata/no_subgraphs.bin
new file mode 100644
index 0000000000..5606898d7f
--- /dev/null
+++ b/tensorflow/contrib/lite/testdata/no_subgraphs.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/testdata/test_model.bin b/tensorflow/contrib/lite/testdata/test_model.bin
new file mode 100644
index 0000000000..2878b1f96e
--- /dev/null
+++ b/tensorflow/contrib/lite/testdata/test_model.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/testdata/test_model_broken.bin b/tensorflow/contrib/lite/testdata/test_model_broken.bin
new file mode 100644
index 0000000000..9fd050cd4a
--- /dev/null
+++ b/tensorflow/contrib/lite/testdata/test_model_broken.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/testdata/test_model_broken.json b/tensorflow/contrib/lite/testdata/test_model_broken.json
new file mode 100644
index 0000000000..b701eb9a25
--- /dev/null
+++ b/tensorflow/contrib/lite/testdata/test_model_broken.json
@@ -0,0 +1,62 @@
+{
+ "subgraphs": [
+ {
+ "inputs": [0, 1],
+ "outputs": [2, 3],
+ "operators": [
+ {
+ "opcode_index": 0,
+ "inputs": [0,1],
+ "outputs": [2]
+ },
+ {
+ "opcode_index": 1,
+ "inputs": [2],
+ "outputs": [3]
+ }
+ ],
+ "tensors": [
+ {
+ "shape" : [
+ 2
+ ],
+ "type" : "FLOAT32",
+ "name" : "input0",
+ "data_buffer" : [1,0,0,0]
+ },
+ {
+ "shape" : [
+ 3
+ ],
+ "type" : "FLOAT32",
+ "name" : "input1",
+ "data_buffer" : []
+ },
+ {
+ "shape" : [
+ 3
+ ],
+ "type" : "FLOAT32",
+ "name" : "out1",
+ "data_buffer" : []
+ },
+ {
+ "shape" : [
+ 3
+ ],
+ "type" : "FLOAT32",
+ "name" : "out2",
+ "data_buffer" : []
+ }
+ ],
+ }
+ ],
+ "operator_codes": [
+ {
+ "builtin_code": 0
+ },
+ {
+ "custom_code": "testing_op"
+ }
+ ]
+}
diff --git a/tensorflow/contrib/lite/testdata/two_subgraphs.bin b/tensorflow/contrib/lite/testdata/two_subgraphs.bin
new file mode 100644
index 0000000000..07308ba62b
--- /dev/null
+++ b/tensorflow/contrib/lite/testdata/two_subgraphs.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
new file mode 100644
index 0000000000..5e40a13d3c
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -0,0 +1,213 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow/contrib/lite:build_def.bzl",
+ "gen_zipped_test_files",
+)
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+)
+
+gen_zipped_test_files(
+ name = "optest",
+ files = [
+ "add.zip",
+ "avg_pool.zip",
+ "concat.zip",
+ "constant.zip",
+ "control_dep.zip",
+ "conv.zip",
+ "depthwiseconv.zip",
+ "fully_connected.zip",
+ "fused_batch_norm.zip",
+ "global_batch_norm.zip",
+ "l2_pool.zip",
+ "l2norm.zip",
+ "local_response_norm.zip",
+ "max_pool.zip",
+ "mul.zip",
+ "relu.zip",
+ "relu1.zip",
+ "relu6.zip",
+ "reshape.zip",
+ "resize_bilinear.zip",
+ "sigmoid.zip",
+ "softmax.zip",
+ "space_to_depth.zip",
+ ],
+)
+
+py_binary(
+ name = "generate_examples",
+ srcs = ["generate_examples.py"],
+ data = [
+ "//tensorflow/contrib/lite/toco",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":generate_examples_report",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:graph_util",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
+ name = "generate_examples_report",
+ srcs = ["generate_examples_report.py"],
+ srcs_version = "PY2AND3",
+)
+
+cc_library(
+ name = "parse_testdata_lib",
+ srcs = ["parse_testdata.cc"],
+ hdrs = ["parse_testdata.h"],
+ deps = [
+ ":message",
+ ":split",
+ ":test_runner",
+ "//tensorflow/contrib/lite:framework",
+ ],
+)
+
+cc_library(
+ name = "message",
+ srcs = ["message.cc"],
+ hdrs = ["message.h"],
+ deps = [":tokenize"],
+)
+
+cc_test(
+ name = "message_test",
+ srcs = ["message_test.cc"],
+ deps = [
+ ":message",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "split",
+ srcs = ["split.cc"],
+ hdrs = ["split.h"],
+ deps = [
+ "//tensorflow/contrib/lite:string",
+ ],
+)
+
+cc_test(
+ name = "split_test",
+ size = "small",
+ srcs = ["split_test.cc"],
+ deps = [
+ ":split",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "tflite_driver",
+ srcs = ["tflite_driver.cc"],
+ hdrs = ["tflite_driver.h"],
+ deps = [
+ ":split",
+ ":test_runner",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ ],
+)
+
+cc_test(
+ name = "tflite_driver_test",
+ size = "small",
+ srcs = ["tflite_driver_test.cc"],
+ data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"],
+ deps = [
+ ":tflite_driver",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "tokenize",
+ srcs = ["tokenize.cc"],
+ hdrs = ["tokenize.h"],
+ deps = [
+ "//tensorflow/contrib/lite:string",
+ ],
+)
+
+cc_test(
+ name = "tokenize_test",
+ srcs = ["tokenize_test.cc"],
+ deps = [
+ ":tokenize",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "test_runner",
+ hdrs = ["test_runner.h"],
+ deps = [
+ "//tensorflow/contrib/lite:string",
+ ],
+)
+
+cc_test(
+ name = "test_runner_test",
+ srcs = ["test_runner_test.cc"],
+ deps = [
+ ":test_runner",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_binary(
+ name = "nnapi_example",
+ srcs = ["nnapi_example.cc"],
+ deps = [
+ ":parse_testdata_lib",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/nnapi:nnapi_lib",
+ ],
+)
+
+tf_cc_test(
+ name = "generated_examples_zip_test",
+ size = "medium",
+ srcs = ["generated_examples_zip_test.cc"],
+ data = [":optest"],
+ shard_count = 10,
+ deps = [
+ ":parse_testdata_lib",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "@com_google_googletest//:gtest",
+ "@com_googlesource_code_re2//:re2",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
new file mode 100644
index 0000000000..86540d58a6
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -0,0 +1,1189 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Generate a series of TensorFlow graphs that become tflite test cases.
+
+Usage:
+
+generate_examples <output directory> zipped
+
+bazel run //tensorflow/contrib/lite/testing:generate_examples
+ third_party/tensorflow/contrib/lite/testing/generated_examples zipped
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import itertools
+import os
+import re
+import sys
+import tempfile
+import traceback
+import zipfile
+import numpy as np
+from six import StringIO
+import tensorflow as tf
+from google.protobuf import text_format
+# TODO(aselle): switch to TensorFlow's resource_loader
+from tensorflow.contrib.lite.testing import generate_examples_report as report_lib
+from tensorflow.python.framework import graph_util as tf_graph_util
+
+parser = argparse.ArgumentParser(description="Script to generate TFLite tests.")
+parser.add_argument("output_path",
+ help="Directory where the outputs will be go.")
+# TODO(ahentz): remove this flag
+parser.add_argument("type", help="zipped")
+parser.add_argument("--zip_to_output",
+ type=str,
+ help="Particular zip to output.",
+ required=False)
+parser.add_argument("--toco",
+ type=str,
+ help="Path to toco tool.",
+ required=True)
+parser.add_argument(
+ "--known_bugs_are_errors",
+ action="store_true",
+ help=("If a particular model is affected by a known bug,"
+ " count it as a toco error."))
+parser.add_argument(
+ "--ignore_toco_errors",
+ action="store_true",
+ help="Raise an exception if any toco error is encountered.")
+parser.add_argument(
+ "--save_graphdefs",
+ action="store_true",
+ help="Include intermediate graphdefs in the output zip files.")
+
+
+RANDOM_SEED = 342
+TEST_INPUT_DEPTH = 3
+
+
+# A map from regular expression to bug number. Any test failure with label
+# matching the expression will be considered due to the corresponding bug.
+KNOWN_BUGS = {
+ # TOCO doesn't support scalars as input.
+ r"relu.*input_shape=\[\]": "67587484",
+ r"sigmoid.*input_shape=\[\]": "67645668",
+ # Concat doesn't work with a single input tensor
+ r"concat.*num_tensors=1": "67378344",
+ # Transposition in MatMul is not supported.
+ r"fully_connected.*transpose_.=True": "67586970",
+ # Softmax graphs are too complex.
+ r"softmax.*dim=0": "67749831",
+ r"softmax.*input_shape=\[1,3,4,3\]": "67749831",
+ # SpaceToDepth only supports float32.
+ r"space_to_depth.*(float16|int32|uint8|int64)": "68018134",
+}
+
+
+def toco_options(data_types,
+ input_arrays,
+ output_arrays,
+ shapes,
+ drop_control_dependency):
+ """Create TOCO options to process a model.
+
+ Args:
+ data_types: input and inference types used by TOCO.
+ input_arrays: names of the input tensors
+ output_arrays: name of the output tensors
+ shapes: shapes of the input tensors
+ drop_control_dependency: whether to ignore control dependency nodes.
+
+ Returns:
+ the options in a string.
+ """
+ shape_str = ":".join([",".join(str(y) for y in x) for x in shapes])
+ inference_type = "FLOAT"
+ # TODO(ahentz): if we get multi-input quantization to work we need this
+ # to change
+ if data_types[0] == "QUANTIZED_UINT8":
+ inference_type = "QUANTIZED_UINT8"
+ s = (" --input_types=%s" % ",".join(data_types) +
+ " --inference_type=%s" % inference_type +
+ " --input_format=TENSORFLOW_GRAPHDEF" + " --output_format=TFLITE" +
+ " --input_arrays=%s" % ",".join(input_arrays) +
+ " --input_shapes=%s" % shape_str +
+ " --output_arrays=%s" % ",".join(output_arrays))
+ if drop_control_dependency:
+ s += " --drop_control_dependency"
+ return s
+
+
+def write_toco_options(filename,
+ data_types,
+ input_arrays,
+ output_arrays,
+ shapes,
+ drop_control_dependency=False):
+ """Create TOCO options to process a model.
+
+ Args:
+ filename: Filename to write the options to.
+ data_types: input and inference types used by TOCO.
+ input_arrays: names of the input tensors
+ output_arrays: names of the output tensors
+ shapes: shapes of the input tensors
+ drop_control_dependency: whether to ignore control dependency nodes.
+ """
+ with open(filename, "w") as fp:
+ fp.write(
+ toco_options(
+ data_types=data_types,
+ input_arrays=input_arrays,
+ output_arrays=output_arrays,
+ shapes=shapes,
+ drop_control_dependency=drop_control_dependency))
+
+
+def write_examples(fp, examples):
+ """Given a list `examples`, write a text format representation.
+
+ The file format is csv like with a simple repeated pattern. We would ike
+ to use proto here, but we can't yet due to interfacing with the Android
+ team using this format.
+
+ Args:
+ fp: File-like object to write to.
+ examples: Example dictionary consiting of keys "inputs" and "outputs"
+ """
+
+ def write_tensor(fp, x):
+ """Write tensor in file format supported by TFLITE example."""
+ fp.write("dtype,%s\n" % x.dtype)
+ fp.write("shape," + ",".join(map(str, x.shape)) + "\n")
+ # Output 9 digits after the point to ensure the precision is good enough.
+ values = ["{:.9f}".format(value) for value in list(x.flatten())]
+ fp.write("values," + ",".join(values) + "\n")
+
+ fp.write("test_cases,%d\n" % len(examples))
+ for example in examples:
+ fp.write("inputs,%d\n" % len(example["inputs"]))
+ for i in example["inputs"]:
+ write_tensor(fp, i)
+ fp.write("outputs,%d\n" % len(example["outputs"]))
+ for i in example["outputs"]:
+ write_tensor(fp, i)
+
+
+def write_test_cases(fp, model_name, examples):
+ """Given a dictionary of `examples`, write a text format representation.
+
+ The file format is protocol-buffer-like, even though we don't use proto due
+ to the needs of the Android team.
+
+ Args:
+ fp: File-like object to write to.
+ model_name: Filename where the model was written to, relative to filename.
+ examples: Example dictionary consiting of keys "inputs" and "outputs"
+ """
+
+ fp.write("load_model: %s\n" % os.path.basename(model_name))
+ for example in examples:
+ fp.write("reshape {\n")
+ for t in example["inputs"]:
+ fp.write(" input: \"" + ",".join(map(str, t.shape)) + "\"\n")
+ fp.write("}\n")
+ fp.write("invoke {\n")
+
+ for t in example["inputs"]:
+ values = ["{:.9f}".format(value) for value in list(t.flatten())]
+ fp.write(" input: \"" + ",".join(values) + "\"\n")
+ for t in example["outputs"]:
+ values = ["{:.9f}".format(value) for value in list(t.flatten())]
+ fp.write(" output: \"" + ",".join(values) + "\"\n")
+ fp.write("}\n")
+
+
+_TF_TYPE_INFO = {
+ tf.float32: (np.float32, "FLOAT"),
+ tf.float16: (np.float16, "FLOAT"),
+ tf.int32: (np.int32, "INT32"),
+ tf.uint8: (np.uint8, "QUANTIZED_UINT8"),
+ tf.int64: (np.int64, "INT64"),
+}
+
+
+def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
+ """Build tensor data spreading the range [min_value, max_value)."""
+
+ if dtype in _TF_TYPE_INFO:
+ dtype = _TF_TYPE_INFO[dtype][0]
+
+ if dtype in (tf.float32, tf.float16):
+ value = (max_value-min_value)*np.random.random_sample(shape)+min_value
+ elif dtype in (tf.int32, tf.uint8, tf.int64):
+ value = np.random.random_integers(min_value, max_value, shape)
+ return value.astype(dtype)
+
+
+def freeze_graph(session, outputs):
+ """Freeze the current graph.
+
+ Args:
+ session: Tensorflow sessions containing the graph
+ outputs: List of output tensors
+
+ Returns:
+ The frozen graph_def.
+ """
+ return tf_graph_util.convert_variables_to_constants(
+ session, session.graph.as_graph_def(), [x.op.name for x in outputs])
+
+
+def make_control_dep_tests(zip_path):
+ """Make a set of tests that use control dependencies."""
+
+ test_parameters = [{
+ "input_shape": [[], [1, 1, 1, 1], [1, 15, 14, 1], [3, 15, 14, 3]],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ filter_value = tf.zeros((3, 3, TEST_INPUT_DEPTH, 8), tf.float32)
+ assert_op = tf.assert_greater_equal(input_tensor, input_tensor - 1)
+ with tf.control_dependencies([assert_op]):
+ out = tf.nn.conv2d(input_tensor, filter_value,
+ strides=(1, 1, 1, 1), padding="SAME")
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(tf.float32, parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs,
+ drop_control_dependency=True)
+
+
+def toco_convert(graph_def_str, input_tensors, output_tensors,
+ drop_control_dependency=False):
+ """Convert a model's graph def into a tflite model.
+
+ NOTE: this currently shells out to the toco binary, but we would like
+ convert to Python API tooling in the future.
+
+ Args:
+ graph_def_str: Graph def proto in serialized string format.
+ input_tensors: List of input tensor tuples `(name, shape, type)`
+ output_tensors: List of output tensors (names)
+ drop_control_dependency: whether to ignore control dependency nodes.
+
+ Returns:
+ output tflite model, log_txt from conversion
+ or None, log_txt if it did not convert properly.
+ """
+ data_types = [_TF_TYPE_INFO[x[2]][1] for x in input_tensors]
+ opts = toco_options(
+ data_types=data_types,
+ input_arrays=[x[0] for x in input_tensors],
+ shapes=[x[1] for x in input_tensors],
+ output_arrays=output_tensors,
+ drop_control_dependency=drop_control_dependency)
+
+ with tempfile.NamedTemporaryFile() as graphdef_file, \
+ tempfile.NamedTemporaryFile() as output_file, \
+ tempfile.NamedTemporaryFile("w+") as stdout_file:
+ graphdef_file.write(graph_def_str)
+ graphdef_file.flush()
+
+ # TODO(aselle): Switch this to subprocess at some point.
+ cmd = ("%s --input_file=%s --output_file=%s %s > %s 2>&1" %
+ (bin_path, graphdef_file.name, output_file.name, opts,
+ stdout_file.name))
+ exit_code = os.system(cmd)
+ log = (
+ cmd + "exited with code %d" % exit_code + "\n------------------\n" +
+ stdout_file.read())
+ return (None if exit_code != 0 else output_file.read()), log
+
+
+def make_zip_of_tests(zip_path,
+ test_parameters,
+ make_graph,
+ make_test_inputs,
+ drop_control_dependency=False):
+ """Helper to make a zip file of a bunch of TensorFlow models.
+
+ This does a cartestian product of the dictionary of test_parameters and
+ calls make_graph() for each item in the cartestian product set.
+ If the graph is built successfully, then make_test_inputs() is called to
+ build expected input/output value pairs. The model is then converted to tflite
+ with toco, and the examples are serialized with the tflite model into a zip
+ file (2 files per item in the cartesian product set).
+
+ Args:
+ zip_path: Path of zip file to write
+ test_parameters: Dictionary mapping to lists for each parameter.
+ e.g. `{"strides": [[1,3,3,1], [1,2,2,1]], "foo": [1.2, 1.3]}`
+ make_graph: function that takes current parameters and returns tuple
+ `[input1, input2, ...], [output1, output2, ...]`
+ make_test_inputs: function taking `curr_params`, `session`, `input_tensors`,
+ `output_tensors` and returns tuple `(input_values, output_values)`.
+ drop_control_dependency: whether to ignore control dependency nodes.
+ Raises:
+ RuntimeError: if there are toco errors that can't be ignored.
+ """
+
+ # TODO(aselle): Make this allow multiple inputs outputs.
+ archive = zipfile.PyZipFile(zip_path, "w")
+ zip_manifest = []
+ convert_report = []
+ toco_errors = 0
+ for parameters in test_parameters:
+ keys = parameters.keys()
+ for curr in itertools.product(*parameters.values()):
+ label = zip_path.replace(".zip", "") + (",".join(
+ "%s=%r" % z for z in sorted(zip(keys, curr))).replace(" ", ""))
+ if label[0] == "/":
+ label = label[1:]
+ param_dict = dict(zip(keys, curr))
+
+ def build_example(label, param_dict_real):
+ """Build the model with parameter values set in param_dict_real.
+
+ Args:
+ label: Label of the model (i.e. the filename in the zip).
+ param_dict_real: Parameter dictionary (arguments to the factories
+ make_graph and make_test_inputs)
+ Returns:
+ (tflite_model_binary, report) where tflite_model_binary is the
+ serialized flatbuffer as a string and report is a dictionary with
+ keys `toco_log` (log of toco conversion), `tf_log` (log of tf
+ conversion), `toco` (a string of success status of the conversion),
+ `tf` (a string success status of the conversion).
+ """
+
+ np.random.seed(RANDOM_SEED)
+ report = {"toco": report_lib.NOTRUN, "tf": report_lib.FAILED}
+
+ # Build graph
+ report["tf_log"] = ""
+ report["toco_log"] = ""
+ tf.reset_default_graph()
+
+ try:
+ inputs, outputs = make_graph(param_dict_real)
+ except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError,
+ ValueError):
+ report["tf_log"] += traceback.format_exc()
+ return None, report
+
+ sess = tf.Session()
+ try:
+ baseline_inputs, baseline_outputs = (make_test_inputs(
+ param_dict_real, sess, inputs, outputs))
+ except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError,
+ ValueError):
+ report["tf_log"] += traceback.format_exc()
+ return None, report
+ report["toco"] = report_lib.FAILED
+ report["tf"] = report_lib.SUCCESS
+
+ # Convert graph to toco
+ tflite_model_binary, toco_log = toco_convert(
+ sess.graph_def.SerializeToString(),
+ [(input_tensor.name.split(":")[0], input_tensor.get_shape(),
+ input_tensor.dtype) for input_tensor in inputs],
+ [out.name.split(":")[0]
+ for out in outputs], drop_control_dependency)
+ report["toco"] = (report_lib.SUCCESS if tflite_model_binary is not None
+ else report_lib.FAILED)
+ report["toco_log"] = toco_log
+
+ if FLAGS.save_graphdefs:
+ archive.writestr(label + ".pb",
+ text_format.MessageToString(sess.graph_def),
+ zipfile.ZIP_DEFLATED)
+
+ if tflite_model_binary:
+ archive.writestr(label + ".bin", tflite_model_binary,
+ zipfile.ZIP_DEFLATED)
+ example = {"inputs": baseline_inputs, "outputs": baseline_outputs}
+
+ example_fp = StringIO()
+ write_examples(example_fp, [example])
+ archive.writestr(label + ".inputs",
+ example_fp.getvalue(), zipfile.ZIP_DEFLATED)
+
+ example_fp2 = StringIO()
+ write_test_cases(example_fp2, label + ".bin", [example])
+ archive.writestr(label + "_tests.txt",
+ example_fp2.getvalue(), zipfile.ZIP_DEFLATED)
+
+ zip_manifest.append(label + "\n")
+
+ return tflite_model_binary, report
+
+ _, report = build_example(label, param_dict)
+
+ if report["toco"] == report_lib.FAILED:
+ ignore_error = False
+ if not FLAGS.known_bugs_are_errors:
+ for pattern, bug_number in KNOWN_BUGS.items():
+ if re.search(pattern, label):
+ print("Ignored TOCO error due to bug %s" % bug_number)
+ ignore_error = True
+ if not ignore_error:
+ toco_errors += 1
+ print("-----------------\ntoco error!\n%s\n-----------------\n" %
+ report["toco_log"])
+
+ convert_report.append((param_dict, report))
+ report_io = StringIO()
+ report_lib.make_report_table(report_io, zip_path, convert_report)
+ archive.writestr("report.html", report_io.getvalue())
+
+ archive.writestr("manifest.txt", "".join(zip_manifest), zipfile.ZIP_DEFLATED)
+
+ # Log statistics of what succeeded
+ total_conversions = len(convert_report)
+ tf_success = sum(1 for x in convert_report
+ if x[1]["tf"] == report_lib.SUCCESS)
+ toco_success = sum(1 for x in convert_report
+ if x[1]["toco"] == report_lib.SUCCESS)
+ percent = 0
+ if tf_success > 0:
+ percent = float(toco_success) / float(tf_success) * 100.
+ tf.logging.info(("Archive %s Considered %d graphs, %d TF evaluated graphs "
+ " and %d TOCO converted graphs (%.1f%%"), zip_path,
+ total_conversions, tf_success, toco_success, percent)
+
+ if not FLAGS.ignore_toco_errors and toco_errors > 0:
+ raise RuntimeError(
+ "Found %d errors while generating toco models" % toco_errors)
+
+
+def make_pool_tests(pool_op_in):
+ """Make a set of tests to do average pooling.
+
+ Args:
+ pool_op_in: TensorFlow pooling operation to test i.e. `tf.nn.avg_pool`.
+
+ Returns:
+ A function representing the true generator (after curried pool_op_in).
+ """
+
+ pool_op = pool_op_in
+
+ def f(zip_path):
+ """Actual function that generates examples.
+
+ Args:
+ zip_path: path to write zip to.
+ """
+
+ # Chose a set of parameters
+ test_parameters = [{
+ "ksize": [[2, 1, 1, 2], [1, 1, 1, 1], [1, 1, 2, 1], [1, 10, 11, 1]],
+ "strides": [[2, 1, 1, 2], [1, 1, 1, 1], [1, 1, 2, 1], [1, 10, 11, 1]],
+ # TODO(aselle): should add in a degenerate shape (e.g. [1, 0, 1, 1]).
+ "input_shape": [[], [1, 1, 1, 1], [1, 15, 14, 1], [3, 15, 14, 3]],
+ "padding": ["SAME", "VALID"],
+ "data_format": ["NHWC"], # TODO(aselle): NCHW would be good
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ out = pool_op(
+ input_tensor,
+ ksize=parameters["ksize"],
+ strides=parameters["strides"],
+ data_format=parameters["data_format"],
+ padding=parameters["padding"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(tf.float32, parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+ return f
+
+
+def make_relu_tests(zip_path):
+ """Make a set of tests to do relu."""
+
+ # Chose a set of parameters
+ test_parameters = [{
+ "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
+ [3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ out = tf.nn.relu(input_tensor)
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(
+ np.float32, parameters["input_shape"], min_value=-4, max_value=10)
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_relu1_tests(zip_path):
+ """Make a set of tests to do relu1."""
+
+ # Chose a set of parameters
+ test_parameters = [{
+ "input_shape": [[], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3],
+ [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ # Note that the following is not supported:
+ # out = tf.maximum(-1.0, tf.minimum(input_tensor, 1.0))
+ out = tf.minimum(1.0, tf.maximum(input_tensor, -1.0))
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(
+ np.float32, parameters["input_shape"], min_value=-3, max_value=10)
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_relu6_tests(zip_path):
+ """Make a set of tests to do relu6."""
+
+ # Chose a set of parameters
+ test_parameters = [{
+ "input_shape": [[], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3],
+ [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ out = tf.nn.relu(input_tensor)
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(
+ np.float32, parameters["input_shape"], min_value=-3, max_value=10)
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+# This function tests various TensorFLow functions that generates Const op,
+# including `tf.ones`, `tf.zeros` and random functions.
+def make_constant_tests(zip_path):
+ """Make a set of tests to do constant ops."""
+
+ test_parameters = [{
+ "dtype": [tf.float32, tf.int32],
+ "input_shape": [[1], [2], [1, 1, 1, 1], [2, 2, 2, 2]],
+ }]
+
+ def build_graph(parameters):
+ # Since Toco & Tflite can't have a single constant op in the entire graph,
+ # this test adds a zero tesnor with a constant op tensor.
+ input1 = tf.placeholder(dtype=parameters["dtype"], name="input1",
+ shape=parameters["input_shape"])
+ out = tf.ones(parameters["input_shape"], dtype=parameters["dtype"]) + input1
+ return [input1], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input1 = np.zeros(parameters["input_shape"],
+ dtype=_TF_TYPE_INFO[parameters["dtype"]][0])
+ return [input1], sess.run(outputs, feed_dict={inputs[0]: input1})
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_add_tests(zip_path):
+ """Make a set of tests to do add with and without broadcast."""
+
+ # These parameters are split because we don't support broadcasting.
+ test_parameters = [{
+ "dtype": [tf.float32, tf.int32],
+ "input_shape_1": [[1, 3, 4, 3]],
+ "input_shape_2": [[1, 3, 4, 3]],
+ }, {
+ "dtype": [tf.float32],
+ "input_shape_1": [[5]],
+ "input_shape_2": [[5]],
+ }, {
+ "dtype": [tf.float32],
+ "input_shape_1": [[1, 3, 4, 3]],
+ "input_shape_2": [[3]],
+ }]
+
+ def build_graph(parameters):
+ input1 = tf.placeholder(dtype=parameters["dtype"], name="input1",
+ shape=parameters["input_shape_1"])
+ input2 = tf.placeholder(dtype=parameters["dtype"], name="input2",
+ shape=parameters["input_shape_2"])
+ out = tf.add(input1, input2)
+ return [input1, input2], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input1 = create_tensor_data(parameters["dtype"],
+ parameters["input_shape_1"])
+ input2 = create_tensor_data(parameters["dtype"],
+ parameters["input_shape_2"])
+ return [input1, input2], sess.run(
+ outputs, feed_dict={
+ inputs[0]: input1,
+ inputs[1]: input2
+ })
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_mul_tests(zip_path):
+ """Make a set of tests to do mul with and without broadcast."""
+
+ # These parameters are split because we don't support broadcasting.
+ test_parameters = [{
+ "dtype": [tf.float32, tf.int32],
+ "input_shape_1": [[1, 3, 4, 3]],
+ "input_shape_2": [[1, 3, 4, 3]],
+ }, {
+ "dtype": [tf.float32],
+ "input_shape_1": [[5]],
+ "input_shape_2": [[5]],
+ }, {
+ "dtype": [tf.float32],
+ "input_shape_1": [[1, 3, 4, 3]],
+ "input_shape_2": [[3]],
+ }]
+
+ def build_graph(parameters):
+ input1 = tf.placeholder(dtype=parameters["dtype"], name="input1",
+ shape=parameters["input_shape_1"])
+ input2 = tf.placeholder(dtype=parameters["dtype"], name="input2",
+ shape=parameters["input_shape_2"])
+ out = tf.multiply(input1, input2)
+ return [input1, input2], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input1 = create_tensor_data(parameters["dtype"],
+ parameters["input_shape_1"])
+ input2 = create_tensor_data(parameters["dtype"],
+ parameters["input_shape_2"])
+ return [input1, input2], sess.run(
+ outputs, feed_dict={inputs[0]: input1,
+ inputs[1]: input2})
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_global_batch_norm_tests(zip_path):
+ """Make a set of tests to do batch_norm_with_global_normalization."""
+
+ test_parameters = [{
+ "dtype": [tf.float32],
+ "input_shape": [[1, 1, 6, 2], [3, 4, 5, 4]],
+ "epsilon": [0.1, 0.0001],
+ "scale_after": [True, False],
+ }]
+
+ def build_graph(parameters):
+ """Build the global batch norm testing graph."""
+ input_shape = parameters["input_shape"]
+ scale_shape = input_shape[3]
+
+ scale = create_tensor_data(parameters["dtype"], scale_shape)
+ offset = create_tensor_data(parameters["dtype"], scale_shape)
+ mean = create_tensor_data(parameters["dtype"], scale_shape)
+ variance = create_tensor_data(parameters["dtype"], scale_shape)
+
+ x = create_tensor_data(parameters["dtype"], parameters["input_shape"])
+ x_norm = tf.nn.batch_norm_with_global_normalization(
+ x, mean, variance, scale, offset,
+ parameters["epsilon"], parameters["scale_after"])
+
+ input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input",
+ shape=parameters["input_shape"])
+ out = tf.add(input_tensor, x_norm)
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ return [input_value], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_fused_batch_norm_tests(zip_path):
+ """Make a set of tests to do fused_batch_norm."""
+
+ test_parameters = [{
+ "dtype": [tf.float32],
+ "input_shape": [[1, 1, 6, 2]],
+ "epsilon": [0.001, 0.1],
+ }]
+
+ def build_graph(parameters):
+ """Build the testing graph for fused batch normalization."""
+ input_shape = parameters["input_shape"]
+ scale_shape = input_shape[3]
+
+ scale = create_tensor_data(parameters["dtype"], scale_shape)
+ offset = create_tensor_data(parameters["dtype"], scale_shape)
+ mean = create_tensor_data(parameters["dtype"], scale_shape)
+ variance = create_tensor_data(parameters["dtype"], scale_shape)
+
+ x = create_tensor_data(parameters["dtype"], parameters["input_shape"])
+ [x_norm, _, _] = tf.nn.fused_batch_norm(
+ x, scale, offset, mean, variance,
+ parameters["epsilon"], data_format="NHWC", is_training=False)
+
+ input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input",
+ shape=parameters["input_shape"])
+ out = tf.add(input_tensor, x_norm)
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ return [input_value], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_conv_tests(zip_path):
+ """Make a set of tests to do convolution."""
+
+ test_parameters = [{
+ "input_shape": [[1, 3, 4, 3]],
+ "filter_shape": [[1, 1, 3, 2]],
+ "strides": [[1, 1, 1, 1], [1, 2, 3, 1]],
+ "padding": ["SAME", "VALID"],
+ "data_format": ["NHWC"], # TODO(aselle): NCHW would be good
+ }, {
+ "input_shape": [[2, 14, 14, 2]],
+ "filter_shape": [[6, 6, 2, 2]],
+ "strides": [[1, 1, 1, 1], [1, 2, 3, 1]],
+ "padding": ["SAME", "VALID"],
+ "data_format": ["NHWC"], # TODO(aselle): NCHW would be good
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ filter_values = create_tensor_data(np.float32, parameters["filter_shape"])
+ out = tf.nn.conv2d(input_tensor, filter_values,
+ strides=parameters["strides"],
+ padding=parameters["padding"],
+ data_format=parameters["data_format"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(np.float32, parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_depthwiseconv_tests(zip_path):
+ """Make a set of tests to do convolution."""
+
+ # Tensorflow only supports equal strides
+ test_parameters = [{
+ "input_shape": [[1, 3, 4, 3], [1, 10, 10, 3]],
+ "filter_size": [[1, 1], [1, 2], [3, 3]],
+ "strides": [[1, 1, 1, 1], [1, 3, 3, 1]],
+ "channel_multiplier": [1, 2],
+ "rate": [[1, 1]],
+ "padding": ["SAME", "VALID"],
+ "data_format": ["NHWC"],
+ }, {
+ "input_shape": [[1, 3, 4, 3]],
+ "filter_size": [[1, 1]],
+ "strides": [[1, 1, 2, 1]], # TF needs [1, x, x, 1]
+ "channel_multiplier": [2],
+ "rate": [[2, 2]], # Only [1, 1] is supported
+ "padding": ["SAME"],
+ "data_format": ["NHWC"],
+ }]
+
+ def build_graph(parameters):
+ """Build a depthwise conv graph given `parameters`."""
+ input_shape = parameters["input_shape"]
+ filter_size = parameters["filter_size"]
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=input_shape)
+ filter_shape = filter_size + [
+ input_shape[3], parameters["channel_multiplier"]]
+ filter_values = create_tensor_data(np.float32, filter_shape)
+ out = tf.nn.depthwise_conv2d(
+ input_tensor, filter_values,
+ strides=parameters["strides"],
+ rate=parameters["rate"],
+ padding=parameters["padding"],
+ data_format=parameters["data_format"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(np.float32, parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_concatenation_tests(zip_path):
+ """Make a set of tests to do concatenatinon."""
+
+ test_parameters = [{
+ "base_shape": [[1, 3, 4, 3], [3, 4]],
+ "num_tensors": [1, 2, 3, 4, 5, 6],
+ "axis": [0, 1, 2, 3],
+ }]
+
+ def get_shape(parameters, delta):
+ """Return a tweaked version of 'base_shape'."""
+ axis = parameters["axis"]
+ shape = parameters["base_shape"][:]
+ if axis < len(shape):
+ shape[axis] += delta
+ return shape
+
+ def build_graph(parameters):
+ all_tensors = []
+ for n in range(0, parameters["num_tensors"]):
+ input_tensor = tf.placeholder(dtype=tf.float32, name=("input%d" % n),
+ shape=get_shape(parameters, n))
+ all_tensors.append(input_tensor)
+ out = tf.concat(all_tensors, parameters["axis"])
+ return all_tensors, [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ all_values = []
+ for n in range(0, parameters["num_tensors"]):
+ input_values = create_tensor_data(np.float32,
+ get_shape(parameters, n))
+ all_values.append(input_values)
+ return all_values, sess.run(
+ outputs, feed_dict=dict(zip(inputs, all_values)))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_fully_connected_tests(zip_path):
+ """Make a set of tests to do fully_connected."""
+
+ test_parameters = [{
+ "shape1": [[3, 3]],
+ "shape2": [[3, 3]],
+ "transpose_a": [True, False],
+ "transpose_b": [True, False],
+ }, {
+ "shape1": [[4, 4], [1, 4], [4]],
+ "shape2": [[4, 4], [4, 1], [4]],
+ "transpose_a": [False],
+ "transpose_b": [False],
+ }, {
+ "shape1": [[40, 37]],
+ "shape2": [[37, 40]],
+ "transpose_a": [False],
+ "transpose_b": [False],
+
+ }]
+
+ def build_graph(parameters):
+ input_tensor1 = tf.placeholder(dtype=tf.float32, name="input1",
+ shape=parameters["shape1"])
+ input_tensor2 = create_tensor_data(np.float32, parameters["shape2"])
+ out = tf.matmul(input_tensor1, input_tensor2,
+ transpose_a=parameters["transpose_a"],
+ transpose_b=parameters["transpose_b"])
+ return [input_tensor1], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values1 = create_tensor_data(np.float32, shape=parameters["shape1"])
+ return [input_values1], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values1])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_l2norm_tests(zip_path):
+ """Make a set of tests to do l2norm."""
+
+ # Chose a set of parameters
+ test_parameters = [{
+ "input_shape": [[5, 7], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3],
+ [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
+ "dim": [0, 1, 2, 3, [2, 3], -2],
+ "epsilon": [None, 1e-12, 1e-3],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ if parameters["epsilon"]:
+ out = tf.nn.l2_normalize(
+ input_tensor, parameters["dim"], epsilon=parameters["epsilon"])
+ else:
+ out = tf.nn.l2_normalize(input_tensor, parameters["dim"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(
+ np.float32, parameters["input_shape"], min_value=-4, max_value=10)
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_local_response_norm_tests(zip_path):
+ """Make a set of tests to do local_response_norm."""
+
+ # Chose a set of parameters
+ test_parameters = [{
+ "input_shape": [[1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3]],
+ "depth_radius": [None, 0, 1, 3, 4, 5],
+ "bias": [None, 0.1, 0.3, -0.1],
+ "alpha": [None, 1, 2, -3],
+ "beta": [None, 0.5, 0.25, 2],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ out = tf.nn.local_response_normalization(
+ input_tensor, depth_radius=parameters["depth_radius"],
+ bias=parameters["bias"], alpha=parameters["alpha"],
+ beta=parameters["beta"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(
+ np.float32, parameters["input_shape"], min_value=-4, max_value=10)
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_reshape_tests(zip_path):
+ """Make a set of tests to do reshape."""
+
+ # Alll shapes below are suitable for tensors with 420 elements.
+ test_parameters = [{
+ "dtype": [tf.float32, tf.int32],
+ "input_shape": [[3, 4, 5, 7], [4, 105], [21, 5, 2, 2], [420]],
+ "output_shape": [[15, 28], [420], [1, -1, 5, 7], [-1]],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input",
+ shape=parameters["input_shape"])
+ out = tf.reshape(input_tensor, shape=parameters["output_shape"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_resize_bilinear_tests(zip_path):
+ """Make a set of tests to do resize_bilinear."""
+
+ test_parameters = [{
+ "dtype": [tf.float32, tf.int32],
+ "input_shape": [[1, 3, 4, 3], [1, 10, 2, 1]],
+ "size": [[1, 1], [4, 3], [2, 2], [5, 6]],
+ "align_corners": [None, True, False],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input",
+ shape=parameters["input_shape"])
+ out = tf.image.resize_bilinear(input_tensor, size=parameters["size"],
+ align_corners=parameters["align_corners"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_sigmoid_tests(zip_path):
+ """Make a set of tests to do sigmoid."""
+
+ test_parameters = [{
+ "dtype": [tf.float32],
+ "input_shape": [[1, 3, 4, 3], [4], [], [1, 2, 3, 4, 5, 6]],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input",
+ shape=parameters["input_shape"])
+ out = tf.sigmoid(input_tensor)
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_softmax_tests(zip_path):
+ """Make a set of tests to do softmax."""
+
+ test_parameters = [{
+ "dtype": [tf.float32],
+ "input_shape": [[1, 3, 4, 3], [2, 3]],
+ "dim": [-1, 0],
+ }, {
+ "dtype": [tf.float32],
+ "input_shape": [[4, 7]],
+ "dim": [-1, 1],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input",
+ shape=parameters["input_shape"])
+ out = tf.nn.softmax(input_tensor, dim=parameters["dim"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_space_to_depth_tests(zip_path):
+ """Make a set of tests to do space_to_depth."""
+
+ test_parameters = [{
+ "dtype": [tf.float32, tf.float16, tf.int32, tf.uint8, tf.int64],
+ "input_shape": [[2, 12, 24, 1]],
+ "block_size": [2, 3, 4],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input",
+ shape=parameters["input_shape"])
+ out = tf.space_to_depth(input_tensor, block_size=parameters["block_size"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_l2_pool(input_tensor, ksize, strides, padding, data_format):
+ """Given an input perform a sequence of TensorFlow ops to produce l2pool."""
+ return tf.sqrt(tf.nn.avg_pool(
+ tf.square(input_tensor), ksize=ksize, strides=strides,
+ padding=padding, data_format=data_format))
+
+
+# Toco binary path provided by the generate rule.
+bin_path = None
+
+
+def main(unused_args):
+ global bin_path
+ def mkdir_if_not_exist(x):
+ if not os.path.isdir(x):
+ os.mkdir(x)
+ if not os.path.isdir(x):
+ raise RuntimeError("Failed to create dir %r" % x)
+
+ if FLAGS.type == "zipped":
+ opstest_path = os.path.join(FLAGS.output_path)
+ mkdir_if_not_exist(opstest_path)
+ def _path(filename):
+ return os.path.join(opstest_path, filename)
+
+ dispatch = {
+ "control_dep.zip": make_control_dep_tests,
+ "add.zip": make_add_tests,
+ "conv.zip": make_conv_tests,
+ "constant.zip": make_constant_tests,
+ "depthwiseconv.zip": make_depthwiseconv_tests,
+ "concat.zip": make_concatenation_tests,
+ "fully_connected.zip": make_fully_connected_tests,
+ "global_batch_norm.zip": make_global_batch_norm_tests,
+ "fused_batch_norm.zip": make_fused_batch_norm_tests,
+ "l2norm.zip": make_l2norm_tests,
+ "local_response_norm.zip": make_local_response_norm_tests,
+ "mul.zip": make_mul_tests,
+ "relu.zip": make_relu_tests,
+ "relu1.zip": make_relu1_tests,
+ "relu6.zip": make_relu6_tests,
+ "l2_pool.zip": make_pool_tests(make_l2_pool),
+ "avg_pool.zip": make_pool_tests(tf.nn.avg_pool),
+ "max_pool.zip": make_pool_tests(tf.nn.max_pool),
+ "reshape.zip": make_reshape_tests,
+ "resize_bilinear.zip": make_resize_bilinear_tests,
+ "sigmoid.zip": make_sigmoid_tests,
+ "softmax.zip": make_softmax_tests,
+ "space_to_depth.zip": make_space_to_depth_tests,
+ }
+ out = FLAGS.zip_to_output
+ bin_path = FLAGS.toco
+ if out in dispatch:
+ dispatch[out](_path(out))
+ else:
+ raise RuntimeError("Invalid zip to output %r" % out)
+
+ else:
+ raise RuntimeError("Invalid argument for type of generation.")
+
+
+if __name__ == "__main__":
+ FLAGS, unparsed = parser.parse_known_args()
+
+ if unparsed:
+ print("Usage: %s <path out> zipped <zip file to generate>")
+ else:
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/contrib/lite/testing/generate_examples_report.py b/tensorflow/contrib/lite/testing/generate_examples_report.py
new file mode 100644
index 0000000000..7bcf8cd86a
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/generate_examples_report.py
@@ -0,0 +1,125 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Make HTML tables that report where TF and TOCO failed to convert models.
+
+This is primarily used by generate_examples.py. See it or
+`make_report_table` for more details on usage.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import cgi
+import json
+
+FAILED = "FAILED"
+SUCCESS = "SUCCESS"
+NOTRUN = "NOTRUN"
+
+
+def make_report_table(fp, title, reports):
+ """Make an HTML report of the success/failure reports.
+
+ Args:
+ fp: File-like object in which to put the html.
+ title: "Title of the zip file this pertains to."
+ reports: a list of conversion attempts. (report_args, report_vals) i.e.
+ ({"shape": [1,2,3], "type": "tf.float32"},
+ {"tf": "SUCCESS", "toco": "FAILURE", "toco_log": "Unsupported type.",
+ "tf_log": ""})
+ """
+ # sort reports by if TOCO failure and then TF failure (reversed)
+ reports.sort(key=lambda x: x[1]["toco"], reverse=False)
+ reports.sort(key=lambda x: x[1]["tf"], reverse=True)
+ def result_cell(x, row, col):
+ """Produce a cell with the condition string `x`."""
+ s = cgi.escape(repr(x), quote=True)
+ color = "#44ff44" if x == SUCCESS else (
+ "#ff4444" if x == FAILED else "#eeeeee")
+ handler = "ShowLog(%d, %d)" % (row, col)
+ fp.write("<td style='background-color: %s' onclick='%s'>%s</td>\n" % (
+ color, handler, s))
+
+ fp.write("""<html>
+<head>
+<title>tflite report</title>
+<style>
+body { font-family: Arial; }
+th { background-color: #555555; color: #eeeeee; }
+td { vertical-align: top; }
+td.horiz {width: 50%;}
+pre { white-space: pre-wrap; word-break: keep-all; }
+table {width: 100%;}
+</style>
+</head>
+""")
+ # Write the log data to a javascript variable and also make a function
+ # in javascript to show the log when an item is clicked.
+ fp.write("<script> \n")
+ fp.write("""
+function ShowLog(row, col) {
+
+var log = document.getElementById("log");
+log.innerHTML = "<pre>" + data[row][col] + "</pre>";
+}
+""")
+ fp.write("var data = \n")
+ fp.write(json.dumps([[cgi.escape(x[1]["tf_log"], quote=True),
+ cgi.escape(x[1]["toco_log"], quote=True)]
+ for x in reports]))
+ fp.write(";</script>\n")
+
+ # Write the main table and use onclick on the items that have log items.
+ fp.write("""
+<body>
+<h1>TOCO Conversion</h1>
+<h2>%s</h2>
+""" % title)
+
+ # Get a list of keys that are in any of the records.
+ param_keys = {}
+ for params, _ in reports:
+ for k in params.keys():
+ param_keys[k] = True
+
+ fp.write("<table>\n")
+ fp.write("<tr><td class='horiz'>\n")
+ fp.write("<div style='height:1000px; overflow:auto'>\n")
+ fp.write("<table>\n")
+ fp.write("<tr>\n")
+ for p in param_keys:
+ fp.write("<th>%s</th>\n" % cgi.escape(p, quote=True))
+ fp.write("<th>TensorFlow</th>\n")
+ fp.write("<th>TOCO</th>\n")
+ fp.write("</tr>\n")
+ for idx, (params, vals) in enumerate(reports):
+ fp.write("<tr>\n")
+ for p in param_keys:
+ fp.write(" <td>%s</td>\n" % cgi.escape(repr(params[p]), quote=True))
+
+ result_cell(vals["tf"], idx, 0)
+ result_cell(vals["toco"], idx, 1)
+ fp.write("</tr>\n")
+ fp.write("</table>\n")
+ fp.write("</div>\n")
+ fp.write("</td>\n")
+ fp.write("<td class='horiz' id='log'></td></tr>\n")
+ fp.write("</table>\n")
+ fp.write("<script>\n")
+ fp.write("</script>\n")
+ fp.write("""
+ </body>
+ </html>
+ """)
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
new file mode 100644
index 0000000000..e7df97ee54
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -0,0 +1,279 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdarg>
+#include <cstdio>
+#include <cstdlib>
+#include <fstream>
+#include <map>
+#include <sstream>
+#include <gtest/gtest.h>
+#include "re2/re2.h"
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/testing/parse_testdata.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/subprocess.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace {
+bool FLAGS_ignore_known_bugs = true;
+} // namespace
+
+namespace tflite {
+namespace testing {
+
+// TensorFlow system environment for file system called.
+tensorflow::Env* env = tensorflow::Env::Default();
+
+// List of tests that are expected to fail when
+// --test_arg=--ignore_known_bugs=false
+// Key is a substring of the test name and value is a bug number.
+// TODO(ahentz): make sure we clean this list up frequently.
+std::map<string, string> kBrokenTests = {
+ // Add doesn't support broadcasting.
+ {R"(addd.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"},
+ {R"(muld.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"},
+
+ // Add only supports float32. (and "constant" tests use Add)
+ {R"(addd.*int32)", "68808744"},
+ {R"(constant.*int32)", "68808744"},
+ {R"(mul.*int32)", "68808744"},
+
+ // Toco or TFLite has a bug to deal with some constant functions with
+ // more than 1 element.
+ {R"(constant.*input_shape=\[(2|2,2,2,2)\])", "68721522"},
+
+ // L2Norm only supports 4D tensors.
+ {R"(l2normdim=.*,epsilon=.*,input_shape=\[.,.\])", "67963684"},
+ {R"(l2normdim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"},
+
+ // L2Norm only works for dim=-1.
+ {R"(l2normdim=-2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
+ {R"(l2normdim=-2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+ {R"(l2normdim=2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
+ {R"(l2normdim=2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+ {R"(l2normdim=0,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
+ {R"(l2normdim=0,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+ {R"(l2normdim=1,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
+ {R"(l2normdim=1,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+ {R"(l2normdim=\[2,3\],epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
+ {R"(l2normdim=\[2,3\],epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+
+ // ResizeBilinear looks completely incompatible with Tensorflow
+ {R"(resize_bilinear)", "67964336"},
+};
+
+// Allows test data to be unzipped into a temporary directory and makes
+// sure those temporary directories are removed later.
+class ZipEnvironment : public ::testing::Environment {
+ public:
+ ~ZipEnvironment() override {}
+
+ // Delete all temporary directories on teardown.
+ void TearDown() override {
+ for (const auto& dir : temporary_directories_) {
+ tensorflow::int64 undeleted_dirs, undeleted_files;
+ TF_CHECK_OK(
+ env->DeleteRecursively(dir, &undeleted_dirs, &undeleted_files));
+ }
+ temporary_directories_.clear();
+ }
+
+ // Unzip `zip` file into a new temporary directory `out_dir`.
+ tensorflow::Status UnZip(const std::string& zip, std::string* out_dir) {
+ string dir;
+ TF_CHECK_OK(MakeTemporaryDirectory(&dir));
+ tensorflow::SubProcess proc;
+ std::string unzip_binary =
+ "/usr/bin/unzip";
+ proc.SetProgram(unzip_binary, {"unzip", "-d", dir, zip.c_str()});
+ proc.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE);
+ proc.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE);
+ if (!proc.Start())
+ return tensorflow::Status(tensorflow::error::UNKNOWN,
+ "unzip couldn't start");
+ string out, err;
+ int status = proc.Communicate(nullptr, &out, &err);
+ if (WEXITSTATUS(status) == 0) {
+ *out_dir = dir;
+ return tensorflow::Status::OK();
+ } else {
+ return tensorflow::Status(tensorflow::error::UNKNOWN, "unzip failed");
+ }
+ }
+
+ private:
+ // Make a temporary directory and return its name in `temporary`.
+ tensorflow::Status MakeTemporaryDirectory(string* temporary) {
+ if (env->LocalTempFilename(temporary)) {
+ TF_CHECK_OK(env->CreateDir(*temporary));
+ temporary_directories_.push_back(*temporary);
+ return tensorflow::Status::OK();
+ }
+ return tensorflow::Status(tensorflow::error::UNKNOWN,
+ "make temporary directory failed");
+ }
+
+ std::vector<string> temporary_directories_;
+};
+
+// Return the singleton zip_environment.
+ZipEnvironment* zip_environment() {
+ static ZipEnvironment* env = new ZipEnvironment;
+ return env;
+}
+
+// Read the manifest.txt out of the unarchived zip file. Specifically
+// `original_file` is the original zip file for error messages. `dir` is
+// the temporary directory where the zip file has been unarchived and
+// `test_paths` is the list of test prefixes that were in the manifest.
+// Note, it is an error for a manifest to contain no tests.
+tensorflow::Status ReadManifest(const std::string& original_file,
+ const std::string& dir,
+ std::vector<std::string>* test_paths) {
+ // Read the newline delimited list of entries in the manifest.
+ std::ifstream manifest_fp(dir + "/manifest.txt");
+ std::string manifest((std::istreambuf_iterator<char>(manifest_fp)),
+ std::istreambuf_iterator<char>());
+ size_t pos = 0;
+ int added = 0;
+ while (true) {
+ size_t end_pos = manifest.find("\n", pos);
+ if (end_pos == std::string::npos) break;
+ std::string filename = manifest.substr(pos, end_pos - pos);
+ test_paths->push_back(dir + "/" + filename);
+ pos = end_pos + 1;
+ added += 1;
+ }
+ if (!added) {
+ std::string message = "Test had no examples: " + original_file;
+ return tensorflow::Status(tensorflow::error::UNKNOWN, message.c_str());
+ }
+ return tensorflow::Status::OK();
+}
+
+// Get a list of tests from a zip file `zip_file_name`.
+std::vector<std::string> UnarchiveZipAndFindTestNames(
+ const std::string& zip_file_name) {
+ std::string zip_file = ::tensorflow::testing::TensorFlowSrcRoot() +
+ "/contrib/lite/testing/optest/" + zip_file_name;
+ std::string decompress_tmp_dir;
+ TF_CHECK_OK(zip_environment()->UnZip(zip_file, &decompress_tmp_dir));
+ std::vector<std::string> stuff;
+ TF_CHECK_OK(ReadManifest(zip_file, decompress_tmp_dir, &stuff));
+ return stuff;
+}
+
+class OpsTest : public ::testing::TestWithParam<std::string> {};
+
+TEST_P(OpsTest, RunStuff) {
+ std::string test_path = GetParam();
+ std::string tflite_file = test_path + ".bin";
+ std::string tflite_examples = test_path + ".inputs";
+ auto model = tflite::FlatBufferModel::BuildFromFile(tflite_file.c_str());
+ std::unique_ptr<tflite::Interpreter> interpreter;
+
+ tflite::ops::builtin::BuiltinOpResolver builtins;
+ ASSERT_EQ(tflite::InterpreterBuilder(*model, builtins)(&interpreter),
+ kTfLiteOk);
+
+ std::vector<tflite::testing::Example> examples;
+ ASSERT_EQ(tflite::testing::ParseExamples(tflite_examples.c_str(), &examples),
+ kTfLiteOk);
+
+ string bug_number;
+ for (const auto& p : kBrokenTests) {
+ if (RE2::PartialMatch(test_path, p.first)) {
+ bug_number = p.second;
+ }
+ }
+
+ for (const auto& example : examples) {
+ ASSERT_EQ(interpreter->inputs().size(), example.inputs.size());
+ auto result = [&]() {
+ TF_LITE_ENSURE_STATUS(FeedExample(interpreter.get(), example));
+ TF_LITE_ENSURE_STATUS(interpreter->Invoke());
+ TF_LITE_ENSURE_STATUS(CheckOutputs(interpreter.get(), example));
+ return kTfLiteOk;
+ }();
+
+ if (bug_number.empty()) {
+ ASSERT_EQ(result, kTfLiteOk);
+ } else {
+ if (FLAGS_ignore_known_bugs) {
+ ASSERT_EQ(result, kTfLiteError)
+ << "Not failing as expected dut to http://b/" << bug_number;
+ } else {
+ ASSERT_EQ(result, kTfLiteOk)
+ << "Possibly due to http://b/" << bug_number;
+ }
+ }
+ }
+}
+
+// Instantiate a test. This assumes `zip_base`.zip is a declared data file
+// of this test.
+#define INSTANTIATE_TESTS(zip_base) \
+ INSTANTIATE_TEST_CASE_P( \
+ zip_base, OpsTest, \
+ ::testing::ValuesIn(UnarchiveZipAndFindTestNames(#zip_base ".zip")));
+
+INSTANTIATE_TESTS(add)
+INSTANTIATE_TESTS(avg_pool)
+INSTANTIATE_TESTS(concat)
+INSTANTIATE_TESTS(constant)
+INSTANTIATE_TESTS(control_dep)
+INSTANTIATE_TESTS(conv)
+INSTANTIATE_TESTS(depthwiseconv)
+INSTANTIATE_TESTS(fully_connected)
+INSTANTIATE_TESTS(fused_batch_norm)
+INSTANTIATE_TESTS(global_batch_norm)
+INSTANTIATE_TESTS(l2norm)
+INSTANTIATE_TESTS(l2_pool)
+INSTANTIATE_TESTS(local_response_norm)
+INSTANTIATE_TESTS(max_pool)
+INSTANTIATE_TESTS(mul)
+INSTANTIATE_TESTS(relu)
+INSTANTIATE_TESTS(relu1)
+INSTANTIATE_TESTS(relu6)
+INSTANTIATE_TESTS(reshape)
+INSTANTIATE_TESTS(resize_bilinear)
+INSTANTIATE_TESTS(sigmoid)
+INSTANTIATE_TESTS(softmax)
+INSTANTIATE_TESTS(space_to_depth)
+
+} // namespace testing
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::AddGlobalTestEnvironment(tflite::testing::zip_environment());
+
+ std::vector<tensorflow::Flag> flags = {tensorflow::Flag(
+ "ignore_known_bugs", &FLAGS_ignore_known_bugs,
+ "If a particular model is affected by a known bug, the "
+ "corresponding test should expect the outputs to not match.")};
+ bool success = tensorflow::Flags::Parse(&argc, argv, flags);
+ if (!success || (argc == 2 && !strcmp(argv[1], "--helpfull"))) {
+ fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
+ return 1;
+ }
+
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/testing/message.cc b/tensorflow/contrib/lite/testing/message.cc
new file mode 100644
index 0000000000..03fae4bb86
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/message.cc
@@ -0,0 +1,96 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/message.h"
+
+#include <stack>
+
+#include "tensorflow/contrib/lite/testing/tokenize.h"
+
+namespace tflite {
+namespace testing {
+
+// A token processor that builds messages and forward calls to the current
+// message object. Place a new message at the top of the stack when it start
+// and remove it when it is finished.
+class MessageStack : public TokenProcessor {
+ public:
+ // Start a new MessageStack with the given first_node, which will be used to
+ // process freestanding fields and submessages.
+ explicit MessageStack(Message* first_node) {
+ nodes_.push(first_node);
+ valid_ = true;
+ }
+
+ void ConsumeToken(std::string* token) override {
+ if (!valid_) return;
+ Message* current_node = nodes_.top();
+ if (*token == "{") {
+ // This is the beginning of a new message, names after the previous token.
+ if (previous_token_.empty()) {
+ valid_ = false;
+ return;
+ }
+ nodes_.push(current_node ? current_node->AddChild(previous_token_)
+ : nullptr);
+ previous_token_.clear();
+ } else if (*token == "}") {
+ // A message is being completed. There should be no previous token. Note
+ // that the top-level message never closes, so we should always have at
+ // least one entry in the stack.
+ if (nodes_.size() == 1 || !previous_token_.empty()) {
+ valid_ = false;
+ return;
+ }
+ if (current_node) {
+ current_node->Finish();
+ }
+ nodes_.pop();
+ } else if (*token == ":") {
+ // We reached the end of the 'key' portion of a field. Store the token
+ // until we have the 'value' portion.
+ if (previous_token_.empty()) {
+ valid_ = false;
+ return;
+ }
+ } else {
+ if (previous_token_.empty()) {
+ previous_token_.swap(*token);
+ } else {
+ // This is the 'value' portion of a field. The previous token is the
+ // 'key'.
+ if (current_node) {
+ current_node->SetField(previous_token_, *token);
+ }
+ previous_token_.clear();
+ }
+ }
+ }
+
+ bool valid() const { return valid_; }
+
+ private:
+ std::stack<Message*> nodes_;
+ std::string previous_token_;
+ bool valid_;
+};
+
+bool Message::Read(std::istream* input, Message* message) {
+ MessageStack stack(message);
+ Tokenize(input, &stack);
+ return stack.valid();
+}
+
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/message.h b/tensorflow/contrib/lite/testing/message.h
new file mode 100644
index 0000000000..78ef7e2cbe
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/message.h
@@ -0,0 +1,82 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace tflite {
+namespace testing {
+
+// A Message is a textual protobuf-like structure that looks like:
+// tag {
+// f : "values"
+// child {
+// a : 1
+// }
+// }
+// This class provides the framework for processing message but does not
+// associate any particular behavior to fields and submessage. In order
+// to properly parse a stream this class must be derived.
+class Message {
+ public:
+ // Reads a stream, tokenizes it and create a new message under the given
+ // top-level message. Returns true if the parsing succeeded.
+ static bool Read(std::istream* input, Message* message);
+
+ Message() {}
+ virtual ~Message() {}
+
+ // Called when a new field is found. For example, when:
+ // f : "values"
+ // is found, it triggers:
+ // SetField("f", "values");
+ virtual void SetField(const std::string& name, const std::string& value) {}
+
+ // Called when a submessage is started. For example, when:
+ // child {
+ // is found, it triggers
+ // AddChild("child");
+ // If nullptr is returned, the contents of the submessage will be ignored.
+ // Otherwise, the returned Message will be used to handle new fields and new
+ // submessages. The caller should not take ownership of the returned pointer.
+ virtual Message* AddChild(const std::string& name) { return nullptr; }
+
+ // Called when a submessage is completed, that is, whenever a '}' is found.
+ virtual void Finish() {}
+
+ protected:
+ // Takes ownership of the given pointer. Subclasses can use this method if
+ // they don't want to implement their own ownership semantics.
+ Message* Store(Message* n) {
+ children_.emplace_back(n);
+ return n;
+ }
+
+ // Returns a list of all owned submessages.
+ const std::vector<std::unique_ptr<Message>>& Children() const {
+ return children_;
+ }
+
+ private:
+ std::vector<std::unique_ptr<Message>> children_;
+};
+
+} // namespace testing
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_
diff --git a/tensorflow/contrib/lite/testing/message_test.cc b/tensorflow/contrib/lite/testing/message_test.cc
new file mode 100644
index 0000000000..fb6a49bd6f
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/message_test.cc
@@ -0,0 +1,121 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/message.h"
+
+#include <map>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace testing {
+namespace {
+
+// A hierarchical, key-value store.
+class TestMessage : public Message {
+ public:
+ TestMessage() {}
+ explicit TestMessage(const std::string& text_to_parse) {
+ std::stringstream ss(text_to_parse);
+ finished_ = Message::Read(&ss, this);
+ }
+ void SetField(const std::string& name, const std::string& value) override {
+ fields_[name] = value;
+ }
+ Message* AddChild(const std::string& name) override {
+ TestMessage* m = new TestMessage;
+ m->name_ = name;
+ return Store(m);
+ }
+ void Finish() override { finished_ = true; }
+
+ int NumChildren() const { return Children().size(); }
+
+ const TestMessage* GetChild(int i) const {
+ return dynamic_cast<TestMessage*>(Children()[i].get());
+ }
+
+ int NumFields() const { return fields_.size(); }
+ const std::string& GetField(const std::string& key) const {
+ return fields_.at(key);
+ }
+
+ const std::string& name() const { return name_; }
+ bool finished() const { return finished_; }
+
+ protected:
+ std::string name_;
+ std::map<std::string, std::string> fields_;
+ bool finished_ = false;
+};
+
+TEST(MessageTest, Simple) {
+ TestMessage message("x{a:1 b:2} y{} z{c:3} d:4");
+ ASSERT_TRUE(message.finished());
+
+ ASSERT_EQ(message.NumFields(), 1);
+ EXPECT_EQ(message.GetField("d"), "4");
+
+ ASSERT_EQ(message.NumChildren(), 3);
+
+ auto* x = message.GetChild(0);
+ EXPECT_EQ(x->name(), "x");
+ ASSERT_EQ(x->NumFields(), 2);
+ EXPECT_EQ(x->GetField("a"), "1");
+ EXPECT_EQ(x->GetField("b"), "2");
+
+ auto* y = message.GetChild(1);
+ EXPECT_EQ(y->name(), "y");
+ ASSERT_EQ(y->NumFields(), 0);
+
+ auto* z = message.GetChild(2);
+ EXPECT_EQ(z->name(), "z");
+ ASSERT_EQ(z->NumFields(), 1);
+ EXPECT_EQ(z->GetField("c"), "3");
+}
+
+TEST(MessageTest, Unnamed) {
+ TestMessage message("x{c:3} {} y{d:4}");
+ ASSERT_FALSE(message.finished());
+ EXPECT_EQ(message.NumChildren(), 1);
+}
+
+TEST(MessageTest, TooManyBraces) {
+ TestMessage message("x{c:3} } y{d:4}");
+ ASSERT_FALSE(message.finished());
+ EXPECT_EQ(message.NumChildren(), 1);
+}
+
+TEST(MessageTest, LeftoverToken) {
+ TestMessage message("x{c:3} z{test} y{d:4}");
+ ASSERT_FALSE(message.finished());
+ EXPECT_EQ(message.NumChildren(), 2);
+}
+
+TEST(MessageTest, MissingKey) {
+ TestMessage message("x{c:3} z{:test} y{d:4}");
+ ASSERT_FALSE(message.finished());
+ EXPECT_EQ(message.NumChildren(), 2);
+}
+
+TEST(MessageTest, MissingValue) {
+ TestMessage message("x{c:3} z{test:} y{d:4}");
+ ASSERT_FALSE(message.finished());
+ EXPECT_EQ(message.NumChildren(), 2);
+}
+
+} // namespace
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/nnapi_example.cc b/tensorflow/contrib/lite/testing/nnapi_example.cc
new file mode 100644
index 0000000000..74f6cfc3de
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/nnapi_example.cc
@@ -0,0 +1,114 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// NOTE: this is an example driver that converts a tflite model to TensorFlow.
+// This is an example that will be integrated more tightly into tflite in
+// the future.
+//
+// Usage: bazel run -c opt \
+// tensorflow/contrib/lite/nnapi:nnapi_example -- <filename>
+//
+#include <cstdarg>
+#include <cstdio>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h"
+#include "tensorflow/contrib/lite/testing/parse_testdata.h"
+
+// TODO(aselle): FATAL leaves resources hanging.
+void FATAL(const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ vfprintf(stderr, format, args);
+ va_end(args);
+ fflush(stderr);
+ exit(1);
+}
+
+#define CHECK_TFLITE_SUCCESS(x) \
+ if (x != kTfLiteOk) { \
+ FATAL("Aborting since tflite returned failure."); \
+ }
+
+void Interpret(const char* filename, const char* examples_filename,
+ bool use_nnapi) {
+ // TODO(aselle): Resize of input image should go here
+ // ...
+ // For now I am allocating all tensors. This means I am fixed size.
+ // So I am not using the variable size ability yet.
+ fprintf(stderr, "example file %s\n", examples_filename);
+ std::vector<tflite::testing::Example> examples;
+ CHECK_TFLITE_SUCCESS(
+ tflite::testing::ParseExamples(examples_filename, &examples));
+
+ for (const tflite::testing::Example& example : examples) {
+ auto model = tflite::FlatBufferModel::BuildFromFile(filename);
+ if (!model) FATAL("Cannot read file %s\n", filename);
+ std::unique_ptr<tflite::Interpreter> interpreter;
+ tflite::ops::builtin::BuiltinOpResolver builtins;
+
+ CHECK_TFLITE_SUCCESS(
+ tflite::InterpreterBuilder(*model, builtins)(&interpreter));
+
+ printf("Use nnapi is set to: %d\n", use_nnapi);
+ interpreter->UseNNAPI(use_nnapi);
+ CHECK_TFLITE_SUCCESS(
+ tflite::testing::FeedExample(interpreter.get(), example));
+
+ {
+ TfLiteTensor* tensor = interpreter->tensor(interpreter->outputs()[0]);
+ if (float* data =
+ interpreter->typed_tensor<float>(interpreter->outputs()[0])) {
+ size_t num = tensor->bytes / sizeof(float);
+ for (float* p = data; p < data + num; p++) {
+ *p = 0;
+ }
+ }
+ }
+ interpreter->Invoke();
+
+ CHECK_TFLITE_SUCCESS(
+ tflite::testing::CheckOutputs(interpreter.get(), example));
+
+ printf("Result:\n");
+ TfLiteTensor* tensor = interpreter->tensor(interpreter->outputs()[0]);
+ if (float* data =
+ interpreter->typed_tensor<float>(interpreter->outputs()[0])) {
+ size_t num = tensor->bytes / sizeof(float);
+ for (float* p = data; p < data + num; p++) {
+ printf(" %f", *p);
+ }
+ }
+ }
+}
+
+int main(int argc, char* argv[]) {
+ bool use_nnapi = true;
+ if (argc == 4) {
+ use_nnapi = strcmp(argv[3], "1") == 0 ? true : false;
+ }
+ if (argc < 3) {
+ fprintf(stderr,
+ "Compiled " __DATE__ __TIME__
+ "\n"
+ "Usage!!!: %s <tflite model> <examples to test> "
+ "{ use nn api i.e. 0,1}\n",
+ argv[0]);
+ return 1;
+ }
+ Interpret(argv[1], argv[2], use_nnapi);
+ return 0;
+}
diff --git a/tensorflow/contrib/lite/testing/parse_testdata.cc b/tensorflow/contrib/lite/testing/parse_testdata.cc
new file mode 100644
index 0000000000..2b67052cad
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/parse_testdata.cc
@@ -0,0 +1,335 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Parses tflite example input data.
+// Format is ASCII
+// TODO(aselle): Switch to protobuf, but the android team requested a simple
+// ASCII file.
+#include "tensorflow/contrib/lite/testing/parse_testdata.h"
+
+#include <cmath>
+#include <cstdint>
+#include <cstdio>
+#include <fstream>
+#include <iostream>
+#include <streambuf>
+
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/testing/message.h"
+#include "tensorflow/contrib/lite/testing/split.h"
+
+namespace tflite {
+namespace testing {
+namespace {
+
+// Fatal error if parse error occurs
+#define PARSE_CHECK_EQ(filename, current_line, x, y) \
+ if ((x) != (y)) { \
+ fprintf(stderr, "Parse Error @ %s:%d\n File %s\n Line %d, %s != %s\n", \
+ __FILE__, __LINE__, filename, current_line + 1, #x, #y); \
+ return kTfLiteError; \
+ }
+
+// Breakup a "," delimited line into a std::vector<std::string>.
+// This is extremely inefficient, and just used for testing code.
+// TODO(aselle): replace with absl when we use it.
+std::vector<std::string> ParseLine(const std::string& line) {
+ size_t pos = 0;
+ std::vector<std::string> elements;
+ while (true) {
+ size_t end = line.find(',', pos);
+ if (end == std::string::npos) {
+ elements.push_back(line.substr(pos));
+ break;
+ } else {
+ elements.push_back(line.substr(pos, end - pos));
+ }
+ pos = end + 1;
+ }
+ return elements;
+}
+
+} // namespace
+
+// Given a `filename`, produce a vector of Examples corresopnding
+// to test cases that can be applied to a tflite model.
+TfLiteStatus ParseExamples(const char* filename,
+ std::vector<Example>* examples) {
+ std::ifstream fp(filename);
+ if (!fp.good()) {
+ fprintf(stderr, "Could not read '%s'\n", filename);
+ return kTfLiteError;
+ }
+ std::string str((std::istreambuf_iterator<char>(fp)),
+ std::istreambuf_iterator<char>());
+ size_t pos = 0;
+
+ // \n and , delimit parse a file.
+ std::vector<std::vector<std::string>> csv;
+ while (true) {
+ size_t end = str.find('\n', pos);
+
+ if (end == std::string::npos) {
+ csv.emplace_back(ParseLine(str.substr(pos)));
+ break;
+ }
+ csv.emplace_back(ParseLine(str.substr(pos, end - pos)));
+ pos = end + 1;
+ }
+
+ int current_line = 0;
+ PARSE_CHECK_EQ(filename, current_line, csv[0][0], "test_cases");
+ int example_count = std::stoi(csv[0][1]);
+ current_line++;
+
+ auto parse_tensor = [&filename, &current_line,
+ &csv](FloatTensor* tensor_ptr) {
+ PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "dtype");
+ current_line++;
+ // parse shape
+ PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "shape");
+ size_t elements = 1;
+ FloatTensor& tensor = *tensor_ptr;
+
+ for (size_t i = 1; i < csv[current_line].size(); i++) {
+ const auto& shape_part_to_parse = csv[current_line][i];
+ if (shape_part_to_parse.empty()) {
+ // Case of a 0-dimensional shape
+ break;
+ }
+ int shape_part = std::stoi(shape_part_to_parse);
+ elements *= shape_part;
+ tensor.shape.push_back(shape_part);
+ }
+ current_line++;
+ // parse data
+ PARSE_CHECK_EQ(filename, current_line, csv[current_line].size() - 1,
+ elements);
+ for (size_t i = 1; i < csv[current_line].size(); i++) {
+ tensor.flat_data.push_back(std::stof(csv[current_line][i]));
+ }
+ current_line++;
+
+ return kTfLiteOk;
+ };
+
+ for (int example_idx = 0; example_idx < example_count; example_idx++) {
+ Example example;
+ PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "inputs");
+ int inputs = std::stoi(csv[current_line][1]);
+ current_line++;
+ // parse dtype
+ for (int input_index = 0; input_index < inputs; input_index++) {
+ example.inputs.push_back(FloatTensor());
+ TF_LITE_ENSURE_STATUS(parse_tensor(&example.inputs.back()));
+ }
+
+ PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "outputs");
+ int outputs = std::stoi(csv[current_line][1]);
+ current_line++;
+ for (int input_index = 0; input_index < outputs; input_index++) {
+ example.outputs.push_back(FloatTensor());
+ TF_LITE_ENSURE_STATUS(parse_tensor(&example.outputs.back()));
+ }
+ examples->emplace_back(example);
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus FeedExample(tflite::Interpreter* interpreter,
+ const Example& example) {
+ // Resize inputs to match example & allocate.
+ for (size_t i = 0; i < interpreter->inputs().size(); i++) {
+ int input_index = interpreter->inputs()[i];
+
+ TF_LITE_ENSURE_STATUS(
+ interpreter->ResizeInputTensor(input_index, example.inputs[i].shape));
+ }
+ TF_LITE_ENSURE_STATUS(interpreter->AllocateTensors());
+ // Copy data into tensors.
+ for (size_t i = 0; i < interpreter->inputs().size(); i++) {
+ int input_index = interpreter->inputs()[i];
+ if (float* data = interpreter->typed_tensor<float>(input_index)) {
+ for (size_t idx = 0; idx < example.inputs[i].flat_data.size(); idx++) {
+ data[idx] = example.inputs[i].flat_data[idx];
+ }
+ } else if (int32_t* data =
+ interpreter->typed_tensor<int32_t>(input_index)) {
+ for (size_t idx = 0; idx < example.inputs[i].flat_data.size(); idx++) {
+ data[idx] = example.inputs[i].flat_data[idx];
+ }
+ } else {
+ fprintf(stderr, "input[%zu] was not float or int data\n", i);
+ return kTfLiteError;
+ }
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter,
+ const Example& example) {
+ constexpr double kRelativeThreshold = 1e-2f;
+ constexpr double kAbsoluteThreshold = 1e-4f;
+
+ ErrorReporter* context = DefaultErrorReporter();
+ int model_outputs = interpreter->outputs().size();
+ TF_LITE_ENSURE_EQ(context, model_outputs, example.outputs.size());
+ for (size_t i = 0; i < interpreter->outputs().size(); i++) {
+ int output_index = interpreter->outputs()[i];
+ if (const float* data = interpreter->typed_tensor<float>(output_index)) {
+ for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) {
+ float computed = data[idx];
+ float reference = example.outputs[0].flat_data[idx];
+ float diff = std::abs(computed - reference);
+ bool error_is_large = false;
+ // For very small numbers, try absolute error, otherwise go with
+ // relative.
+ if (std::abs(reference) < kRelativeThreshold) {
+ error_is_large = (diff > kAbsoluteThreshold);
+ } else {
+ error_is_large = (diff > kRelativeThreshold * std::abs(reference));
+ }
+ if (error_is_large) {
+ fprintf(stdout, "output[%zu][%zu] did not match %f vs reference %f\n",
+ i, idx, data[idx], reference);
+ return kTfLiteError;
+ }
+ }
+ fprintf(stderr, "\n");
+ } else if (const int32_t* data =
+ interpreter->typed_tensor<int32_t>(output_index)) {
+ for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) {
+ int32_t computed = data[idx];
+ int32_t reference = example.outputs[0].flat_data[idx];
+ if (std::abs(computed - reference) > 0) {
+ fprintf(stderr, "output[%zu][%zu] did not match %d vs reference %f\n",
+ i, idx, data[idx], example.outputs[0].flat_data[idx]);
+ return kTfLiteError;
+ }
+ }
+ fprintf(stderr, "\n");
+ } else {
+ fprintf(stderr, "output[%zu] was not float or int data\n", i);
+ return kTfLiteError;
+ }
+ }
+ return kTfLiteOk;
+}
+
+// Process an 'invoke' message, triggering execution of the test runner, as
+// well as verification of outputs. An 'invoke' message looks like:
+// invoke {
+// id: xyz
+// input: 1,2,1,1,1,2,3,4
+// ouput: 4,5,6
+// }
+class Invoke : public Message {
+ public:
+ explicit Invoke(TestRunner* test_runner) : test_runner_(test_runner) {
+ expected_inputs_ = test_runner->GetInputs();
+ expected_outputs_ = test_runner->GetOutputs();
+ }
+
+ void SetField(const std::string& name, const std::string& value) override {
+ if (name == "id") {
+ test_runner_->SetInvocationId(value);
+ } else if (name == "input") {
+ if (expected_inputs_.empty()) {
+ return test_runner_->Invalidate("Too many inputs");
+ }
+ test_runner_->SetInput(*expected_inputs_.begin(), value);
+ expected_inputs_.erase(expected_inputs_.begin());
+ } else if (name == "output") {
+ if (expected_outputs_.empty()) {
+ return test_runner_->Invalidate("Too many outputs");
+ }
+ test_runner_->SetExpectation(*expected_outputs_.begin(), value);
+ expected_outputs_.erase(expected_outputs_.begin());
+ }
+ }
+ void Finish() override {
+ test_runner_->Invoke();
+ test_runner_->CheckResults();
+ }
+
+ private:
+ std::vector<int> expected_inputs_;
+ std::vector<int> expected_outputs_;
+
+ TestRunner* test_runner_;
+};
+
+// Process an 'reshape' message, triggering resizing of the input tensors via
+// the test runner. A 'reshape' message looks like:
+// reshape {
+// input: 1,2,1,1,1,2,3,4
+// }
+class Reshape : public Message {
+ public:
+ explicit Reshape(TestRunner* test_runner) : test_runner_(test_runner) {
+ expected_inputs_ = test_runner->GetInputs();
+ }
+
+ void SetField(const std::string& name, const std::string& value) override {
+ if (name == "input") {
+ if (expected_inputs_.empty()) {
+ return test_runner_->Invalidate("Too many inputs to reshape");
+ }
+ test_runner_->ReshapeTensor(*expected_inputs_.begin(), value);
+ expected_inputs_.erase(expected_inputs_.begin());
+ }
+ }
+
+ private:
+ std::vector<int> expected_inputs_;
+ TestRunner* test_runner_;
+};
+
+// This is the top-level message in a test file.
+class TestData : public Message {
+ public:
+ explicit TestData(TestRunner* test_runner) : test_runner_(test_runner) {}
+
+ void SetField(const std::string& name, const std::string& value) override {
+ if (name == "load_model") {
+ test_runner_->LoadModel(value);
+ } else if (name == "init_state") {
+ test_runner_->AllocateTensors();
+ for (int id : Split<int>(value, ",")) {
+ test_runner_->ResetTensor(id);
+ }
+ }
+ }
+ Message* AddChild(const std::string& s) override {
+ if (s == "invoke") {
+ test_runner_->AllocateTensors();
+ return Store(new Invoke(test_runner_));
+ } else if (s == "reshape") {
+ return Store(new Reshape(test_runner_));
+ }
+ return nullptr;
+ }
+
+ private:
+ TestRunner* test_runner_;
+};
+
+bool ParseAndRunTests(std::istream* input, TestRunner* test_runner) {
+ TestData test_data(test_runner);
+ Message::Read(input, &test_data);
+ return test_runner->IsValid() && test_runner->GetOverallSuccess();
+}
+
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/parse_testdata.h b/tensorflow/contrib/lite/testing/parse_testdata.h
new file mode 100644
index 0000000000..90839fe245
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/parse_testdata.h
@@ -0,0 +1,74 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_
+
+#include <vector>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/testing/test_runner.h"
+
+namespace tflite {
+namespace testing {
+
+// Shape and data for a float tensor
+struct FloatTensor {
+ std::vector<int> shape;
+ std::vector<float> flat_data;
+};
+
+// A prescribed input, output example
+struct Example {
+ std::vector<FloatTensor> inputs;
+ std::vector<FloatTensor> outputs;
+};
+
+// Parses an example input and output file (used for unit tests)
+TfLiteStatus ParseExamples(const char* filename,
+ std::vector<Example>* examples);
+
+// Inputs Tensors into a TensorFlow lite interpreter. Note, this will run
+// interpreter.AllocateTensors();
+TfLiteStatus FeedExample(tflite::Interpreter* interpreter, const Example&);
+
+// Check outputs against (already) evaluated result.
+TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter, const Example&);
+
+// Parses a test description and feeds the given test runner with data.
+// The input format is similar to an ASCII proto:
+// // Loads model 'add.bin' from the TestRunner's model directory.
+// load_model: "add.bin"
+// // Changes the shape of inputs, provided in the same order they appear
+// // in the model.
+// reshape {
+// input: "1,224,224,3"
+// input: "1,3,4,1"
+// }
+// // Fills the given persistent tensors with zeros.
+// init_state: 0,1,2,3
+// // Invokes the interpreter with the given input and checks that it
+// // produces the expected output. Inputs and outputs should be specified in
+// // the order they appear in the model.
+// invoke {
+// input: "1,2,3,4,56"
+// input: "0.1,0.2,0.3,4.3,56.4"
+// output: "12,3,4,545,3"
+// output: "0.01,0.02"
+// }
+bool ParseAndRunTests(std::istream* input, TestRunner* test_runner);
+
+} // namespace testing
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_
diff --git a/tensorflow/contrib/lite/testing/split.cc b/tensorflow/contrib/lite/testing/split.cc
new file mode 100644
index 0000000000..5836f4ff04
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/split.cc
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/split.h"
+
+namespace tflite {
+namespace testing {
+
+std::vector<std::pair<size_t, size_t>> SplitToPos(const string& s,
+ const string& delimiter) {
+ std::vector<std::pair<size_t, size_t>> fields;
+ if (delimiter.length() == 0) {
+ fields.emplace_back(0, s.length());
+ return fields;
+ }
+ size_t pos = 0;
+ size_t start = 0;
+ while ((pos = s.find(delimiter, start)) != string::npos) {
+ if (pos != start) {
+ fields.emplace_back(start, pos);
+ }
+ start = pos + delimiter.length();
+ }
+ if (start != s.length()) {
+ fields.emplace_back(start, s.length());
+ }
+ return fields;
+}
+
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/split.h b/tensorflow/contrib/lite/testing/split.h
new file mode 100644
index 0000000000..24071442e8
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/split.h
@@ -0,0 +1,77 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_
+
+#include <cstdlib>
+#include <string>
+#include <utility>
+#include <vector>
+#include "tensorflow/contrib/lite/string.h"
+
+namespace tflite {
+namespace testing {
+
+// Splits a string based on the given delimiter string. Each pair in the
+// returned vector has the start and past-the-end positions for each of the
+// parts of the original string. Empty fields are not represented in the
+// output.
+std::vector<std::pair<size_t, size_t>> SplitToPos(const string& s,
+ const string& delimiter);
+
+// Splits the given string and converts each part to the given T.
+template <typename T>
+std::vector<T> Split(const string& s, const string& delimiter);
+
+template <>
+inline std::vector<string> Split(const string& s, const string& delimiter) {
+ std::vector<string> fields;
+ for (const auto& p : SplitToPos(s, delimiter)) {
+ fields.push_back(s.substr(p.first, p.second - p.first));
+ }
+ return fields;
+}
+
+template <>
+inline std::vector<int> Split(const string& s, const string& delimiter) {
+ std::vector<int> fields;
+ for (const auto& p : SplitToPos(s, delimiter)) {
+ fields.push_back(strtol(s.data() + p.first, nullptr, 10));
+ }
+ return fields;
+}
+
+template <>
+inline std::vector<float> Split(const string& s, const string& delimiter) {
+ std::vector<float> fields;
+ for (const auto& p : SplitToPos(s, delimiter)) {
+ fields.push_back(strtod(s.data() + p.first, nullptr));
+ }
+ return fields;
+}
+
+template <>
+inline std::vector<uint8_t> Split(const string& s, const string& delimiter) {
+ std::vector<uint8_t> fields;
+ for (const auto& p : SplitToPos(s, delimiter)) {
+ fields.push_back(strtol(s.data() + p.first, nullptr, 10));
+ }
+ return fields;
+}
+
+} // namespace testing
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_
diff --git a/tensorflow/contrib/lite/testing/split_test.cc b/tensorflow/contrib/lite/testing/split_test.cc
new file mode 100644
index 0000000000..3d1e25d9c7
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/split_test.cc
@@ -0,0 +1,57 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/split.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace testing {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Pair;
+
+TEST(SplitTest, SplitToPos) {
+ EXPECT_THAT(SplitToPos("test;:1-2-3 ;: test", ";:"),
+ ElementsAre(Pair(0, 4), Pair(6, 12), Pair(14, 19)));
+ EXPECT_THAT(SplitToPos("test;:1-2-3 ;: test", ":"),
+ ElementsAre(Pair(0, 5), Pair(6, 13), Pair(14, 19)));
+ EXPECT_THAT(SplitToPos("test", ":"), ElementsAre(Pair(0, 4)));
+ EXPECT_THAT(SplitToPos("test ", ":"), ElementsAre(Pair(0, 5)));
+ EXPECT_THAT(SplitToPos("", ":"), ElementsAre());
+ EXPECT_THAT(SplitToPos("test ", ""), ElementsAre(Pair(0, 5)));
+ EXPECT_THAT(SplitToPos("::::", ":"), ElementsAre());
+}
+
+TEST(SplitTest, SplitString) {
+ EXPECT_THAT(Split<string>("A;B;C", ";"), ElementsAre("A", "B", "C"));
+}
+
+TEST(SplitTest, SplitFloat) {
+ EXPECT_THAT(Split<float>("1.0 B 1e-5", " "), ElementsAre(1.0, 0.0, 1e-5));
+}
+
+TEST(SplitTest, SplitInt) {
+ EXPECT_THAT(Split<int>("1,-1,258", ","), ElementsAre(1, -1, 258));
+}
+
+TEST(SplitTest, SplitUint8) {
+ EXPECT_THAT(Split<uint8_t>("1,-1,258", ","), ElementsAre(1, 255, 2));
+}
+
+} // namespace
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/test_runner.h b/tensorflow/contrib/lite/testing/test_runner.h
new file mode 100644
index 0000000000..04ee4d9f7d
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/test_runner.h
@@ -0,0 +1,124 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "tensorflow/contrib/lite/string.h"
+
+namespace tflite {
+namespace testing {
+
+// This is the base class for processing test data. Each one of the virtual
+// methods must be implemented to forward the data to the appropriate executor
+// (e.g. TF Lite's interpreter, or the NNAPI).
+class TestRunner {
+ public:
+ TestRunner() {}
+ virtual ~TestRunner() {}
+
+ // Load the given model, as a path relative to SetModelBaseDir().
+ virtual void LoadModel(const string& bin_file_path) = 0;
+
+ // Return the list of input tensors in the loaded model.
+ virtual const std::vector<int>& GetInputs() = 0;
+
+ // Return the list of output tensors in the loaded model.
+ virtual const std::vector<int>& GetOutputs() = 0;
+
+ // Prepare for a run by resize the given tensor. The given 'id' is
+ // guaranteed to be one of the ids returned by GetInputs().
+ virtual void ReshapeTensor(int id, const string& csv_values) = 0;
+
+ // Reserve memory for all tensors.
+ virtual void AllocateTensors() = 0;
+
+ // Set the given tensor to some initial state, usually zero. This is
+ // used to reset persistent buffers in a model.
+ virtual void ResetTensor(int id) = 0;
+
+ // Define the contents of the given input tensor. The given 'id' is
+ // guaranteed to be one of the ids returned by GetInputs().
+ virtual void SetInput(int id, const string& csv_values) = 0;
+
+ // Define what should be expected for an output tensor after Invoke() runs.
+ // The given 'id' is guaranteed to be one of the ids returned by
+ // GetOutputs().
+ virtual void SetExpectation(int id, const string& csv_values) = 0;
+
+ // Run the model.
+ virtual void Invoke() = 0;
+
+ // Verify that the contents of all ouputs conform to the existing
+ // expectations. Return true if there are no expectations or they are all
+ // satisfied.
+ virtual bool CheckResults() = 0;
+
+ // Set the base path for loading models.
+ void SetModelBaseDir(const string& path) {
+ model_base_dir_ = path;
+ if (path[path.length() - 1] != '/') {
+ model_base_dir_ += "/";
+ }
+ }
+
+ // Return the full path of a model.
+ string GetFullPath(const string& path) { return model_base_dir_ + path; }
+
+ // Give an id to the next invocation to make error reporting more meaningful.
+ void SetInvocationId(const string& id) { invocation_id_ = id; }
+ const string& GetInvocationId() const { return invocation_id_; }
+
+ // Invalidate the test runner, preventing it from executing any further.
+ void Invalidate(const string& error_message) {
+ error_message_ = error_message;
+ }
+ bool IsValid() const { return error_message_.empty(); }
+ const string& GetErrorMessage() const { return error_message_; }
+
+ // Handle the overall success of this test runner. This will be true if all
+ // invocations were successful.
+ void SetOverallSuccess(bool value) { overall_success_ = value; }
+ bool GetOverallSuccess() const { return overall_success_; }
+
+ protected:
+ // A helper to check of the given number of values is consistent with the
+ // number of bytes in a tensor of type T. When incompatibles sizes are found,
+ // the test runner is invalidated and false is returned.
+ template <typename T>
+ bool CheckSizes(size_t tensor_bytes, size_t num_values) {
+ size_t num_tensor_elements = tensor_bytes / sizeof(T);
+ if (num_tensor_elements != num_values) {
+ Invalidate("Expected '" + std::to_string(num_tensor_elements) +
+ "' elements for a tensor, but only got '" +
+ std::to_string(num_values) + "'");
+ return false;
+ }
+ return true;
+ }
+
+ private:
+ string model_base_dir_;
+ string invocation_id_;
+ bool overall_success_ = true;
+
+ string error_message_;
+};
+
+} // namespace testing
+} // namespace tflite
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_
diff --git a/tensorflow/contrib/lite/testing/test_runner_test.cc b/tensorflow/contrib/lite/testing/test_runner_test.cc
new file mode 100644
index 0000000000..f712a5347a
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/test_runner_test.cc
@@ -0,0 +1,84 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/test_runner.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace testing {
+namespace {
+
+class ConcreteTestRunner : public TestRunner {
+ public:
+ void LoadModel(const string& bin_file_path) override {}
+ const std::vector<int>& GetInputs() override { return ids_; }
+ const std::vector<int>& GetOutputs() override { return ids_; }
+ void ReshapeTensor(int id, const string& csv_values) override {}
+ void AllocateTensors() override {}
+ void ResetTensor(int id) override {}
+ void SetInput(int id, const string& csv_values) override {}
+ void SetExpectation(int id, const string& csv_values) override {}
+ void Invoke() override {}
+ bool CheckResults() override { return true; }
+ bool CheckFloatSizes(size_t bytes, size_t values) {
+ return CheckSizes<float>(bytes, values);
+ }
+
+ private:
+ std::vector<int> ids_;
+};
+
+TEST(TestRunner, ModelPath) {
+ ConcreteTestRunner runner;
+ EXPECT_EQ(runner.GetFullPath("test.bin"), "test.bin");
+ runner.SetModelBaseDir("/tmp");
+ EXPECT_EQ(runner.GetFullPath("test.bin"), "/tmp/test.bin");
+}
+
+TEST(TestRunner, InvocationId) {
+ ConcreteTestRunner runner;
+ EXPECT_EQ(runner.GetInvocationId(), "");
+ runner.SetInvocationId("X");
+ EXPECT_EQ(runner.GetInvocationId(), "X");
+}
+
+TEST(TestRunner, Invalidation) {
+ ConcreteTestRunner runner;
+ EXPECT_TRUE(runner.IsValid());
+ EXPECT_EQ(runner.GetErrorMessage(), "");
+ runner.Invalidate("Some Error");
+ EXPECT_FALSE(runner.IsValid());
+ EXPECT_EQ(runner.GetErrorMessage(), "Some Error");
+}
+
+TEST(TestRunner, OverallSuccess) {
+ ConcreteTestRunner runner;
+ EXPECT_TRUE(runner.GetOverallSuccess());
+ runner.SetOverallSuccess(false);
+ EXPECT_FALSE(runner.GetOverallSuccess());
+}
+
+TEST(TestRunner, CheckSizes) {
+ ConcreteTestRunner runner;
+ EXPECT_TRUE(runner.CheckFloatSizes(16, 4));
+ EXPECT_FALSE(runner.CheckFloatSizes(16, 2));
+ EXPECT_EQ(runner.GetErrorMessage(),
+ "Expected '4' elements for a tensor, but only got '2'");
+}
+
+} // namespace
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc
new file mode 100644
index 0000000000..cf9df2ec26
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/tflite_driver.cc
@@ -0,0 +1,208 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/tflite_driver.h"
+
+#include <iostream>
+
+#include "tensorflow/contrib/lite/testing/split.h"
+
+namespace tflite {
+namespace testing {
+
+namespace {
+
+// Returns the value in the given position in a tensor.
+template <typename T>
+T Value(const TfLitePtrUnion& data, int index);
+template <>
+float Value(const TfLitePtrUnion& data, int index) {
+ return data.f[index];
+}
+template <>
+uint8_t Value(const TfLitePtrUnion& data, int index) {
+ return data.uint8[index];
+}
+
+template <typename T>
+void SetTensorData(const std::vector<T>& values, TfLitePtrUnion* data) {
+ T* input_ptr = reinterpret_cast<T*>(data->raw);
+ for (const T& v : values) {
+ *input_ptr = v;
+ ++input_ptr;
+ }
+}
+
+} // namespace
+
+class TfLiteDriver::Expectation {
+ public:
+ Expectation() { data_.raw = nullptr; }
+ ~Expectation() { delete[] data_.raw; }
+ template <typename T>
+ void SetData(const string& csv_values) {
+ const auto& values = testing::Split<T>(csv_values, ",");
+ data_.raw = new char[values.size() * sizeof(T)];
+ SetTensorData(values, &data_);
+ }
+
+ bool Check(bool verbose, const TfLiteTensor& tensor) {
+ switch (tensor.type) {
+ case kTfLiteFloat32:
+ return TypedCheck<float>(verbose, tensor);
+ case kTfLiteUInt8:
+ return TypedCheck<uint8_t>(verbose, tensor);
+ default:
+ return false;
+ }
+ }
+
+ private:
+ template <typename T>
+ bool TypedCheck(bool verbose, const TfLiteTensor& tensor) {
+ int tensor_size = tensor.bytes / sizeof(T);
+
+ bool good_output = true;
+ for (int i = 0; i < tensor_size; ++i) {
+ if (std::abs(Value<T>(data_, i) - Value<T>(tensor.data, i)) > 1e-5) {
+ good_output = false;
+ if (verbose) {
+ std::cerr << " index " << i << ": " << Value<T>(data_, i)
+ << " != " << Value<T>(tensor.data, i) << std::endl;
+ }
+ }
+ }
+ return good_output;
+ }
+
+ TfLitePtrUnion data_;
+};
+
+TfLiteDriver::TfLiteDriver(bool use_nnapi) : use_nnapi_(use_nnapi) {}
+TfLiteDriver::~TfLiteDriver() {}
+
+void TfLiteDriver::AllocateTensors() {
+ if (must_allocate_tensors_) {
+ if (interpreter_->AllocateTensors() != kTfLiteOk) {
+ std::cerr << "Failed to allocate tensors" << std::endl;
+ abort();
+ }
+ must_allocate_tensors_ = false;
+ }
+}
+
+void TfLiteDriver::LoadModel(const string& bin_file_path) {
+ if (!IsValid()) return;
+ std::cout << std::endl << "Loading model: " << bin_file_path << std::endl;
+
+ model_ = FlatBufferModel::BuildFromFile(GetFullPath(bin_file_path).c_str());
+ if (!model_) {
+ Invalidate("Failed to mmap model " + bin_file_path);
+ return;
+ }
+ ops::builtin::BuiltinOpResolver builtins;
+ InterpreterBuilder(*model_, builtins)(&interpreter_);
+ if (!interpreter_) {
+ Invalidate("Failed build interpreter");
+ return;
+ }
+
+ must_allocate_tensors_ = true;
+}
+
+void TfLiteDriver::ResetTensor(int id) {
+ if (!IsValid()) return;
+ auto* tensor = interpreter_->tensor(id);
+ memset(tensor->data.raw, 0, tensor->bytes);
+}
+
+void TfLiteDriver::ReshapeTensor(int id, const string& csv_values) {
+ if (!IsValid()) return;
+ if (interpreter_->ResizeInputTensor(
+ id, testing::Split<int>(csv_values, ",")) != kTfLiteOk) {
+ Invalidate("Failed to resize input tensor " + std::to_string(id));
+ return;
+ }
+ must_allocate_tensors_ = true;
+}
+
+void TfLiteDriver::SetInput(int id, const string& csv_values) {
+ if (!IsValid()) return;
+ auto* tensor = interpreter_->tensor(id);
+ switch (tensor->type) {
+ case kTfLiteFloat32: {
+ const auto& values = testing::Split<float>(csv_values, ",");
+ if (!CheckSizes<float>(tensor->bytes, values.size())) return;
+ SetTensorData(values, &tensor->data);
+ break;
+ }
+ case kTfLiteUInt8: {
+ const auto& values = testing::Split<uint8_t>(csv_values, ",");
+ if (!CheckSizes<uint8_t>(tensor->bytes, values.size())) return;
+ SetTensorData(values, &tensor->data);
+ break;
+ }
+ default:
+ Invalidate("Unsupported tensor data type");
+ return;
+ }
+}
+
+void TfLiteDriver::SetExpectation(int id, const string& csv_values) {
+ if (!IsValid()) return;
+ auto* tensor = interpreter_->tensor(id);
+ expected_output_[id].reset(new Expectation);
+ switch (tensor->type) {
+ case kTfLiteFloat32:
+ expected_output_[id]->SetData<float>(csv_values);
+ break;
+ case kTfLiteUInt8:
+ expected_output_[id]->SetData<uint8_t>(csv_values);
+ break;
+ default:
+ Invalidate("Unsupported tensor data type");
+ return;
+ }
+}
+
+void TfLiteDriver::Invoke() {
+ if (!IsValid()) return;
+ if (interpreter_->Invoke() != kTfLiteOk) {
+ Invalidate("Failed to invoke interpreter");
+ }
+}
+
+bool TfLiteDriver::CheckResults() {
+ if (!IsValid()) return false;
+ bool success = true;
+ for (const auto& p : expected_output_) {
+ int id = p.first;
+ auto* tensor = interpreter_->tensor(id);
+ if (!p.second->Check(/*verbose=*/false, *tensor)) {
+ // Do not invalidate anything here. Instead, simply output the
+ // differences and return false. Invalidating would prevent all
+ // subsequent invocations from running..
+ std::cerr << "There were errors in invocation '" << GetInvocationId()
+ << "', output tensor '" << id << "':" << std::endl;
+ p.second->Check(/*verbose=*/true, *tensor);
+ success = false;
+ SetOverallSuccess(false);
+ }
+ }
+ expected_output_.clear();
+ return success;
+}
+
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h
new file mode 100644
index 0000000000..4440d4285e
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/tflite_driver.h
@@ -0,0 +1,62 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_
+
+#include <map>
+
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/testing/test_runner.h"
+
+namespace tflite {
+namespace testing {
+
+// A test runner that feeds inputs into TF Lite and verifies its outputs.
+class TfLiteDriver : public TestRunner {
+ public:
+ explicit TfLiteDriver(bool use_nnapi);
+ ~TfLiteDriver() override;
+
+ void LoadModel(const string& bin_file_path) override;
+ const std::vector<int>& GetInputs() override {
+ return interpreter_->inputs();
+ }
+ const std::vector<int>& GetOutputs() override {
+ return interpreter_->outputs();
+ }
+ void ReshapeTensor(int id, const string& csv_values) override;
+ void AllocateTensors() override;
+ void ResetTensor(int id) override;
+ void SetInput(int id, const string& csv_values) override;
+ void SetExpectation(int id, const string& csv_values) override;
+ void Invoke() override;
+ bool CheckResults() override;
+
+ private:
+ class Expectation;
+
+ bool use_nnapi_ = false;
+ std::unique_ptr<FlatBufferModel> model_;
+ std::unique_ptr<Interpreter> interpreter_;
+ std::map<int, std::unique_ptr<Expectation>> expected_output_;
+ bool must_allocate_tensors_ = true;
+};
+
+} // namespace testing
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_
diff --git a/tensorflow/contrib/lite/testing/tflite_driver_test.cc b/tensorflow/contrib/lite/testing/tflite_driver_test.cc
new file mode 100644
index 0000000000..37010c468f
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/tflite_driver_test.cc
@@ -0,0 +1,61 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/tflite_driver.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace testing {
+namespace {
+
+using ::testing::ElementsAre;
+
+TEST(TfliteDriverTest, SimpleTest) {
+ std::unique_ptr<TestRunner> runner(new TfLiteDriver(/*use_nnapi=*/false));
+
+ runner->SetModelBaseDir("tensorflow/contrib/lite");
+ runner->LoadModel("testdata/multi_add.bin");
+ ASSERT_TRUE(runner->IsValid());
+
+ ASSERT_THAT(runner->GetInputs(), ElementsAre(0, 1, 2, 3));
+ ASSERT_THAT(runner->GetOutputs(), ElementsAre(5, 6));
+
+ for (int i : {0, 1, 2, 3}) {
+ runner->ReshapeTensor(i, "1,2,2,1");
+ }
+ ASSERT_TRUE(runner->IsValid());
+
+ runner->AllocateTensors();
+
+ runner->SetInput(0, "0.1,0.2,0.3,0.4");
+ runner->SetInput(1, "0.001,0.002,0.003,0.004");
+ runner->SetInput(2, "0.001,0.002,0.003,0.004");
+ runner->SetInput(3, "0.01,0.02,0.03,0.04");
+
+ runner->ResetTensor(2);
+
+ runner->SetExpectation(5, "0.101,0.202,0.303,0.404");
+ runner->SetExpectation(6, "0.011,0.022,0.033,0.044");
+
+ runner->Invoke();
+ ASSERT_TRUE(runner->IsValid());
+
+ ASSERT_TRUE(runner->CheckResults());
+}
+
+} // namespace
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/tokenize.cc b/tensorflow/contrib/lite/testing/tokenize.cc
new file mode 100644
index 0000000000..2e84ea475c
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/tokenize.cc
@@ -0,0 +1,95 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/tokenize.h"
+#include <istream>
+#include <string>
+#include "tensorflow/contrib/lite/string.h"
+
+namespace tflite {
+namespace testing {
+
+void Tokenize(std::istream* input, TokenProcessor* processor) {
+ enum State { kBuildQuotedToken, kBuildToken, kIdle };
+
+ std::string current_token;
+ State state = kIdle;
+ auto start_token = [&](char c) {
+ state = kBuildToken;
+ current_token.clear();
+ current_token = c;
+ };
+ auto issue_token = [&]() {
+ state = kIdle;
+ processor->ConsumeToken(&current_token);
+ current_token.clear();
+ };
+ auto start_quoted_token = [&]() {
+ state = kBuildQuotedToken;
+ current_token.clear();
+ };
+ auto issue_quoted_token = [&]() {
+ state = kIdle;
+ processor->ConsumeToken(&current_token);
+ current_token.clear();
+ };
+ auto issue_delim = [&](char d) {
+ current_token = string(1, d);
+ processor->ConsumeToken(&current_token);
+ current_token.clear();
+ };
+ auto is_delim = [](char c) { return c == '{' || c == '}' || c == ':'; };
+ auto is_quote = [](char c) { return c == '"'; };
+
+ for (auto it = std::istreambuf_iterator<char>(*input);
+ it != std::istreambuf_iterator<char>(); ++it) {
+ switch (state) {
+ case kIdle:
+ if (is_delim(*it)) {
+ issue_delim(*it);
+ } else if (is_quote(*it)) {
+ start_quoted_token();
+ } else if (!isspace(*it)) {
+ start_token(*it);
+ }
+ break;
+ case kBuildToken:
+ if (is_delim(*it)) {
+ issue_token();
+ issue_delim(*it);
+ } else if (is_quote(*it)) {
+ issue_token();
+ start_quoted_token();
+ } else if (isspace(*it)) {
+ issue_token();
+ } else {
+ current_token += *it;
+ }
+ break;
+ case kBuildQuotedToken:
+ if (is_quote(*it)) {
+ issue_quoted_token();
+ } else {
+ current_token += *it;
+ }
+ break;
+ }
+ }
+ if (state != kIdle) {
+ issue_token();
+ }
+}
+
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/tokenize.h b/tensorflow/contrib/lite/testing/tokenize.h
new file mode 100644
index 0000000000..daccf0e84a
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/tokenize.h
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_
+
+#include <istream>
+#include <string>
+
+namespace tflite {
+namespace testing {
+
+// Process tokens coming from Tokenize().
+class TokenProcessor {
+ public:
+ virtual ~TokenProcessor() {}
+ // Process a single token. The token won't be reused, so it is OK to call
+ // token.swap().
+ virtual void ConsumeToken(std::string* token) = 0;
+};
+
+// Tokenize a stream on whitespaces, colons and curly braces. Whitespaces are
+// removed from the tokens and double-quotes can be used to avoid that. Note
+// that there is no way to escape double-quotes, so there's no way to have a
+// double-quote inside a token.
+void Tokenize(std::istream* input, TokenProcessor* processor);
+
+} // namespace testing
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_
diff --git a/tensorflow/contrib/lite/testing/tokenize_test.cc b/tensorflow/contrib/lite/testing/tokenize_test.cc
new file mode 100644
index 0000000000..80f44aacca
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/tokenize_test.cc
@@ -0,0 +1,105 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/tokenize.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace testing {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+
+class TokenCollector : public TokenProcessor {
+ public:
+ void ConsumeToken(std::string* token) override { tokens_.push_back(*token); }
+ const std::vector<std::string>& Tokens() { return tokens_; }
+
+ private:
+ std::vector<std::string> tokens_;
+};
+
+std::vector<std::string> TokenizeString(const std::string& s) {
+ std::stringstream ss(s);
+ TokenCollector collector;
+ Tokenize(&ss, &collector);
+ return collector.Tokens();
+}
+
+TEST(TokenizeTest, TokenDetection) {
+ EXPECT_THAT(TokenizeString("x :1"), ElementsAre("x", ":", "1"));
+ EXPECT_THAT(TokenizeString("x:1"), ElementsAre("x", ":", "1"));
+ EXPECT_THAT(TokenizeString("x {1"), ElementsAre("x", "{", "1"));
+ EXPECT_THAT(TokenizeString("x{1"), ElementsAre("x", "{", "1"));
+ EXPECT_THAT(TokenizeString("x }1"), ElementsAre("x", "}", "1"));
+ EXPECT_THAT(TokenizeString("x}1"), ElementsAre("x", "}", "1"));
+ EXPECT_THAT(TokenizeString("x \"1"), ElementsAre("x", "1"));
+ EXPECT_THAT(TokenizeString("x\"1"), ElementsAre("x", "1"));
+}
+
+TEST(TokenizeTest, QuotedTokenDetection) {
+ EXPECT_THAT(TokenizeString("\"w:x{y}z\"1"), ElementsAre("w:x{y}z", "1"));
+ EXPECT_THAT(TokenizeString("\"w:x{y}z\"\"1\""), ElementsAre("w:x{y}z", "1"));
+}
+
+TEST(TokenizeTest, Delimiters) {
+ EXPECT_THAT(TokenizeString("}"), ElementsAre("}"));
+ EXPECT_THAT(TokenizeString("}}"), ElementsAre("}", "}"));
+ EXPECT_THAT(TokenizeString("{"), ElementsAre("{"));
+ EXPECT_THAT(TokenizeString("{{"), ElementsAre("{", "{"));
+ EXPECT_THAT(TokenizeString(":"), ElementsAre(":"));
+ EXPECT_THAT(TokenizeString("::"), ElementsAre(":", ":"));
+}
+
+TEST(TokenizeTest, CornerCases) {
+ EXPECT_THAT(TokenizeString(" i { b:a } "),
+ ElementsAre("i", "{", "b", ":", "a", "}"));
+ EXPECT_THAT(TokenizeString(" }"), ElementsAre("}"));
+ EXPECT_THAT(TokenizeString(" } "), ElementsAre("}"));
+ EXPECT_THAT(TokenizeString(" {} "), ElementsAre("{", "}"));
+ EXPECT_THAT(TokenizeString(" x{} y{} "),
+ ElementsAre("x", "{", "}", "y", "{", "}"));
+ EXPECT_THAT(TokenizeString("x:1 y:2 "),
+ ElementsAre("x", ":", "1", "y", ":", "2"));
+ EXPECT_THAT(TokenizeString("x:\"1\" y:2 "),
+ ElementsAre("x", ":", "1", "y", ":", "2"));
+ EXPECT_THAT(TokenizeString("x:\"1, 2\" y:\"\" "),
+ ElementsAre("x", ":", "1, 2", "y", ":", ""));
+}
+
+TEST(TokenizeTest, NewLines) {
+ EXPECT_THAT(TokenizeString("x:\n1,\n 2 \n y :\n3 \n"),
+ ElementsAre("x", ":", "1,", "2", "y", ":", "3"));
+}
+
+TEST(TokenizeTest, LongString) {
+ EXPECT_THAT(
+ TokenizeString(" i { b:a } input {"
+ "a: \"1e-1, 2,3\" b:\"1,2,3\"\n c{ "
+ "id:1 x{d{a:"
+ "1}}} f:2 "
+ "\n}\n t:1"),
+ ElementsAreArray({"i", "{", "b", ":", "a", "}", "input", "{",
+ "a", ":", "1e-1, 2,3", "b", ":", "1,2,3", "c", "{",
+ "id", ":", "1", "x", "{", "d", "{", "a",
+ ":", "1", "}", "}", "}", "f", ":", "2",
+ "}", "t", ":", "1"}));
+}
+
+} // namespace
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
new file mode 100644
index 0000000000..05e77c330c
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -0,0 +1,350 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow/core:platform/default/build_config.bzl",
+ "tf_proto_library_cc",
+ "tf_proto_library_py",
+)
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_binary",
+ "tf_cc_test",
+)
+
+tf_proto_library_cc(
+ name = "toco_flags_proto",
+ srcs = ["toco_flags.proto"],
+ visibility = ["//visibility:public"],
+)
+
+tf_proto_library_cc(
+ name = "model_flags_proto",
+ srcs = ["model_flags.proto"],
+ visibility = ["//visibility:public"],
+)
+
+tf_proto_library_py(
+ name = "toco_flags_proto",
+ srcs = [
+ "toco_flags.proto",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+tf_proto_library_py(
+ name = "model_flags_proto",
+ srcs = [
+ "model_flags.proto",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "tensorflow_core_cc_protos_all",
+ deps = ["//tensorflow/core:protos_all_cc"],
+)
+
+cc_library(
+ name = "runtime",
+ hdrs = [
+ "runtime/common.h",
+ "runtime/types.h",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/contrib/lite/kernels/internal:reference_base",
+ "//tensorflow/contrib/lite/kernels/internal:types",
+ ],
+)
+
+# :model offers the core data structures representing a model (a.k.a. "graph")
+# for tooling purposes (not needed at inference runtime).
+# That includes the top-level Model structure, and the lower-level Operator,
+# Array, Buffer structures, etc.
+cc_library(
+ name = "model",
+ hdrs = [
+ "model.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":model_flags_proto_cc",
+ ":runtime",
+ ":toco_port",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/base:core_headers",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+)
+
+cc_library(
+ name = "toco_graphviz_dump_options",
+ srcs = [
+ "toco_graphviz_dump_options.cc",
+ ],
+ hdrs = [
+ "toco_graphviz_dump_options.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "toco_cmdline_flags",
+ srcs = [
+ "toco_cmdline_flags.cc",
+ ],
+ hdrs = [
+ "toco_cmdline_flags.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":model_cmdline_flags",
+ ":toco_flags_proto_cc",
+ ":toco_port",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "model_cmdline_flags",
+ srcs = [
+ "model_cmdline_flags.cc",
+ ],
+ hdrs = [
+ "args.h",
+ "model_cmdline_flags.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":model_flags_proto_cc",
+ ":toco_graphviz_dump_options",
+ ":toco_port",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "toco_port",
+ srcs = [
+ "toco_port.cc",
+ ],
+ hdrs = [
+ "format_port.h",
+ "toco_port.h",
+ "toco_types.h",
+ ],
+ deps = [
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ] + select({
+ "//tensorflow:android": [],
+ "//tensorflow:darwin": [],
+ "//tensorflow:ios": [],
+ "//conditions:default": [],
+ "//tensorflow:dummy_disabled_internal": [],
+ }),
+)
+
+cc_library(
+ name = "graph_transformations",
+ srcs = [
+ "graph_transformations/convert_pure_conv_to_depthwise.cc",
+ "graph_transformations/create_im2col_arrays.cc",
+ "graph_transformations/dequantize.cc",
+ "graph_transformations/drop_fake_quant.cc",
+ "graph_transformations/drop_im2col_arrays.cc",
+ "graph_transformations/ensure_bias_vectors.cc",
+ "graph_transformations/fuse_activation_functions.cc",
+ "graph_transformations/fuse_binary_into_following_affine.cc",
+ "graph_transformations/fuse_binary_into_preceding_affine.cc",
+ "graph_transformations/graph_transformations.cc",
+ "graph_transformations/hardcode_min_max.cc",
+ "graph_transformations/identify_l2_normalization.cc",
+ "graph_transformations/identify_l2_pool.cc",
+ "graph_transformations/identify_lstm.cc",
+ "graph_transformations/identify_relu1.cc",
+ "graph_transformations/make_initial_dequantize_operator.cc",
+ "graph_transformations/propagate_array_data_types.cc",
+ "graph_transformations/propagate_fixed_sizes.cc",
+ "graph_transformations/quantize.cc",
+ "graph_transformations/read_fake_quant_min_max.cc",
+ "graph_transformations/remove_final_dequantize_op.cc",
+ "graph_transformations/remove_tensorflow_assert.cc",
+ "graph_transformations/remove_tensorflow_identity.cc",
+ "graph_transformations/remove_trivial_binary.cc",
+ "graph_transformations/remove_trivial_concatenation.cc",
+ "graph_transformations/remove_trivial_concatenation_input.cc",
+ "graph_transformations/remove_trivial_passthrough.cc",
+ "graph_transformations/remove_trivial_passthrough.h",
+ "graph_transformations/remove_trivial_quantized_activation_func.cc",
+ "graph_transformations/remove_trivial_reshape.cc",
+ "graph_transformations/remove_unused_op.cc",
+ "graph_transformations/resolve_batch_normalization.cc",
+ "graph_transformations/resolve_constant_binary.cc",
+ "graph_transformations/resolve_constant_concatenation.cc",
+ "graph_transformations/resolve_constant_fake_quant.cc",
+ "graph_transformations/resolve_constant_tensorflow_shape.cc",
+ "graph_transformations/resolve_constant_unary.cc",
+ "graph_transformations/resolve_mean_attributes.cc",
+ "graph_transformations/resolve_pad_attributes.cc",
+ "graph_transformations/resolve_reorder_axes.cc",
+ "graph_transformations/resolve_reshape_attributes.cc",
+ "graph_transformations/resolve_slice_attributes.cc",
+ "graph_transformations/resolve_strided_slice_attributes.cc",
+ "graph_transformations/resolve_tensorflow_concat.cc",
+ "graph_transformations/resolve_tensorflow_matmul.cc",
+ "graph_transformations/resolve_tensorflow_merge.cc",
+ "graph_transformations/resolve_tensorflow_squeeze.cc",
+ "graph_transformations/resolve_tensorflow_switch.cc",
+ "graph_transformations/resolve_tensorflow_tile.cc",
+ "graph_transformations/unfuse_activation_functions.cc",
+ ],
+ hdrs = [
+ "graph_transformations/graph_transformations.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":model",
+ ":model_flags_proto_cc",
+ ":runtime",
+ ":toco_port",
+ ":tooling_util",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+# :toco_tooling is the library providing the offline tooling functionality
+# exposed by the :toco command-line tool.
+cc_library(
+ name = "toco_tooling",
+ srcs = [
+ "allocate_transient_arrays.cc",
+ "export_tensorflow.cc",
+ "import_tensorflow.cc",
+ "tensorflow_util.cc",
+ "toco_tooling.cc",
+ ],
+ hdrs = [
+ "allocate_transient_arrays.h",
+ "export_tensorflow.h",
+ "import_tensorflow.h",
+ "tensorflow_util.h",
+ "toco_tooling.h",
+ ],
+ copts = select({
+ "//tensorflow:darwin": ["-DTOCO_SUPPORT_PORTABLE_PROTOS=0"],
+ "//conditions:default": [],
+ }),
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_transformations",
+ ":model",
+ ":model_flags_proto_cc",
+ ":runtime",
+ ":toco_graphviz_dump_options",
+ ":toco_flags_proto_cc",
+ ":toco_port",
+ ":tooling_util",
+ "@protobuf_archive//:protobuf_headers",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "//tensorflow/contrib/lite/toco/tensorflow_graph_matching:resolve_cluster",
+ "//tensorflow/contrib/lite/toco/tflite:export",
+ "//tensorflow/contrib/lite/toco/tflite:import",
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ] + select({
+ # Placeholder for internal darwin rule.
+ "//conditions:default": [],
+ }),
+)
+
+cc_library(
+ name = "tooling_util",
+ srcs = [
+ "dump_graphviz.cc",
+ "tooling_util.cc",
+ ],
+ hdrs = [
+ "dump_graphviz.h",
+ "tooling_util.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":model",
+ ":model_flags_proto_cc",
+ ":runtime",
+ ":toco_flags_proto_cc",
+ ":toco_graphviz_dump_options",
+ ":toco_port",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+)
+
+tf_cc_test(
+ name = "tooling_util_test",
+ srcs = ["tooling_util_test.cc"],
+ deps = [
+ ":model",
+ ":tooling_util",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+# :toco is the main public command-line tool exposing the functionality
+# of the :toco_tooling library.
+tf_cc_binary(
+ name = "toco",
+ srcs = ["toco.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":model",
+ ":model_cmdline_flags",
+ ":model_flags_proto_cc",
+ ":toco_cmdline_flags",
+ ":toco_flags_proto_cc",
+ ":toco_port",
+ ":toco_tooling",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+tf_cc_test(
+ name = "toco_port_test",
+ srcs = ["toco_port_test.cc"],
+ data = [
+ "toco_port_test.cc",
+ ],
+ deps = [
+ ":toco_port",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
new file mode 100644
index 0000000000..2f4454d7c8
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
@@ -0,0 +1,318 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <set>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/allocate_transient_arrays.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+namespace {
+
+// The life span of an array.
+struct ArrayLifespan {
+ // If true, the array is persistent state (as in a RNN). In that case,
+ // its allocation is permanent and the first_op, last_op members are
+ // unused. (The term 'transient' is a misnomer and we should think in
+ // terms of 'workspace' instead).
+ bool persistent = false;
+ // Index of the first op addressing that array. The array must be allocated
+ // just before executing this op.
+ std::size_t first_op = 0;
+ // Index of the last op addressing that array. We want to deallocate the array
+ // immediately after executing this op.
+ std::size_t last_op = 0;
+};
+
+bool StartsAt(const ArrayLifespan& lifespan, std::size_t op_index) {
+ return !lifespan.persistent && lifespan.first_op == op_index;
+}
+
+bool EndsAt(const ArrayLifespan& lifespan, std::size_t op_index) {
+ return !lifespan.persistent && lifespan.last_op == op_index;
+}
+
+// Helper function for ComputeArrayLifespans: updates one ArrayLifespan for
+// one array for one op.
+void UpdateArrayLifespan(
+ const string& array_name, std::size_t op_index,
+ std::unordered_map<string, ArrayLifespan>* array_lifespans) {
+ if (array_lifespans->count(array_name)) {
+ auto& lifespan = array_lifespans->at(array_name);
+ if (!lifespan.persistent) {
+ lifespan.first_op = std::min(lifespan.first_op, op_index);
+ lifespan.last_op = std::max(lifespan.last_op, op_index);
+ }
+ } else {
+ ArrayLifespan lifespan;
+ lifespan.first_op = op_index;
+ lifespan.last_op = op_index;
+ (*array_lifespans)[array_name] = lifespan;
+ }
+}
+
+// Computes the ArrayLifespan for each array.
+void ComputeArrayLifespans(
+ const Model& model,
+ std::unordered_map<string, ArrayLifespan>* array_lifespans) {
+ CHECK(array_lifespans->empty());
+ for (const auto& rnn_state : model.flags.rnn_states()) {
+ ArrayLifespan lifespan;
+ lifespan.persistent = true;
+ (*array_lifespans)[rnn_state.state_array()] = lifespan;
+ }
+ for (std::size_t op_index = 0; op_index < model.operators.size();
+ op_index++) {
+ const auto& op = model.operators[op_index];
+ for (const auto& input : op->inputs) {
+ UpdateArrayLifespan(input, op_index, array_lifespans);
+ }
+ for (const auto& output : op->outputs) {
+ UpdateArrayLifespan(output, op_index, array_lifespans);
+ }
+ }
+}
+
+inline bool operator==(const Alloc& a, const Alloc& b) {
+ CHECK(a.start != b.start || a.end == b.end);
+ return a.start == b.start;
+}
+
+// Helper to keep track of total allocation size and of currently live
+// allocations, and containing the core allocation routine.
+class Allocator {
+ public:
+ Allocator() : total_size_(0) {}
+
+ // Core allocation routine.
+ void Allocate(std::size_t size, Alloc* result) {
+ // Naive algorithm: pick the first gap between live allocations,
+ // that is wide enough for the new array.
+ std::size_t pos = 0;
+ for (const auto& a : live_allocs_) {
+ if (a.start >= pos + size) {
+ result->start = pos;
+ result->end = pos + size;
+ live_allocs_.insert(*result);
+ return;
+ }
+ pos = a.end;
+ }
+ // No sufficiently wide gap was found before an existing live allocation,
+ // so we allocate the new array at the end of the allocation space.
+ // We may then have to grow total_size_.
+ total_size_ = std::max(total_size_, pos + size);
+ result->start = pos;
+ result->end = pos + size;
+ live_allocs_.insert(*result);
+ }
+
+ void Deallocate(const Alloc& a) {
+ auto iter = std::lower_bound(live_allocs_.begin(), live_allocs_.end(), a);
+ CHECK(iter != live_allocs_.end());
+ CHECK(*iter == a);
+ live_allocs_.erase(iter);
+ }
+
+ std::size_t total_size() const { return total_size_; }
+
+ private:
+ std::size_t total_size_;
+ std::set<Alloc> live_allocs_;
+};
+
+// Returns the required transient allocation size (in bytes) for a given array,
+// or 0 if it's not a transient array.
+std::size_t TransientArraySize(const Model& model, const string& array_name,
+ std::size_t transient_data_alignment) {
+ if (!IsAllocatableTransientArray(model, array_name)) {
+ return 0;
+ }
+ const auto& array = model.arrays.at(array_name);
+ CHECK(array->has_shape())
+ << "Array '" << array_name << "' doesn't have a shape";
+ if (array->data_type == ArrayDataType::kNone) {
+ // Catch a typical issue at the moment with RNN states
+ for (const auto& rnn_state : model.flags.rnn_states()) {
+ if (rnn_state.state_array() == array_name) {
+ LOG(FATAL)
+ << "A RNN state array, " << array_name << ", still does not "
+ << "have a known data type after all graph transformations have "
+ << "run. That's mostly a toco bug --- sorry. For now, you can "
+ << "work around this issue by adding manually_create:true in the "
+ << "--rnn_state description of this RNN state.";
+ }
+ }
+ LOG(FATAL) << "An array, " << array_name << ", still does not "
+ << "have a known data type after all graph transformations have "
+ << "run.";
+ }
+ const std::size_t elem_size = ElementSize(array->data_type);
+ const std::size_t raw_size =
+ elem_size * RequiredBufferSizeForShape(array->shape());
+ const std::size_t rounded_size =
+ RoundUpToNextMultipleOf(raw_size, transient_data_alignment);
+ return rounded_size;
+}
+
+// Allocates an array: call this for every array just before the first
+// op where it is used.
+void AllocateTransientArray(const Model& model, const string& array_name,
+ Allocator* allocator,
+ std::size_t transient_data_alignment) {
+ if (!IsAllocatableTransientArray(model, array_name)) {
+ return;
+ }
+ const std::size_t size =
+ TransientArraySize(model, array_name, transient_data_alignment);
+ const auto& array = model.arrays.at(array_name);
+ CHECK(!array->alloc);
+ allocator->Allocate(size, &array->GetOrCreateAlloc());
+}
+
+// Deallocates an array: call this for every array just after the last
+// op where it is used.
+void DeallocateTransientArray(const Model& model, const string& array_name,
+ Allocator* allocator) {
+ if (!IsAllocatableTransientArray(model, array_name)) {
+ return;
+ }
+ const auto& array = model.arrays.at(array_name);
+ CHECK(!!array->alloc);
+ allocator->Deallocate(*array->alloc);
+}
+
+} // namespace
+
+void AllocateTransientArrays(Model* model,
+ std::size_t transient_data_alignment) {
+ // Precompute the lifespans for all arrays.
+ std::unordered_map<string, ArrayLifespan> array_lifespans;
+ ComputeArrayLifespans(*model, &array_lifespans);
+
+ // In case of variable batch, our convention will be to compute the
+ // allocations for batch==1, then let the inference code multiply all
+ // the offsets by the actual runtime batch size. Conveniently,
+ // the variable_batch and batch flags are mutually exclusive, and the default
+ // value of batch is 1, so we have nothing special to do here. Let us
+ // just guard this assumption with a CHECK:
+ bool batchless_input_shapes = true;
+ for (const auto& input_array : model->flags.input_arrays()) {
+ if (input_array.shape().empty() || input_array.shape(0) != 1) {
+ batchless_input_shapes = false;
+ break;
+ }
+ }
+ CHECK(!model->flags.variable_batch() || batchless_input_shapes);
+
+ Allocator allocator;
+
+ // Construct a sorted map of array names, so that other layout engines can
+ // match exactly.
+ std::map<string, const Array*> ordered_arrays_map;
+ for (const auto& pair : model->arrays) {
+ ordered_arrays_map[pair.first] = pair.second.get();
+ }
+
+ // Allocate persistent arrays (like RNN states). For them, 'transient'
+ // is a misnormer, should read 'workspace'.
+ for (const auto& array_pair : ordered_arrays_map) {
+ const string& array_name = array_pair.first;
+ const auto& array_lifespan = array_lifespans.find(array_name)->second;
+ if (array_lifespan.persistent) {
+ AllocateTransientArray(*model, array_name, &allocator,
+ transient_data_alignment);
+ }
+ }
+
+ for (std::size_t op_index = 0; op_index < model->operators.size();
+ op_index++) {
+ const auto& op = model->operators[op_index];
+ // Allocate those arrays whose lifespan starts exactly here.
+ for (const auto& input : op->inputs) {
+ if (StartsAt(array_lifespans[input], op_index)) {
+ AllocateTransientArray(*model, input, &allocator,
+ transient_data_alignment);
+ }
+ }
+ for (const auto& output : op->outputs) {
+ if (StartsAt(array_lifespans[output], op_index)) {
+ AllocateTransientArray(*model, output, &allocator,
+ transient_data_alignment);
+ }
+ }
+ // Deallocate those arrays whose lifespan ends exactly here.
+ for (const auto& input : op->inputs) {
+ if (EndsAt(array_lifespans[input], op_index)) {
+ DeallocateTransientArray(*model, input, &allocator);
+ }
+ }
+ for (const auto& output : op->outputs) {
+ if (EndsAt(array_lifespans[output], op_index)) {
+ DeallocateTransientArray(*model, output, &allocator);
+ }
+ }
+ }
+
+ // Just out of curiosity (not used in the actual allocation process)
+ // evaluate the optimal total allocated size.
+ // First, compute the size of persistent arrays.
+ std::size_t optimal_transient_alloc_size = 0;
+ std::size_t persistent_alloc_size = 0;
+ for (const auto& array_pair : ordered_arrays_map) {
+ const string& array_name = array_pair.first;
+ const auto& array_lifespan = array_lifespans.find(array_name)->second;
+ if (array_lifespan.persistent) {
+ persistent_alloc_size +=
+ TransientArraySize(*model, array_name, transient_data_alignment);
+ }
+ }
+ for (const auto& op : model->operators) {
+ // for each operator, compute the sum of the sizes of the array that must
+ // be live during the execution of this operator, plus the size of
+ // persistent arrays that must be live at all times.
+ std::size_t size = persistent_alloc_size;
+ for (const auto& input : op->inputs) {
+ if (!array_lifespans[input].persistent) {
+ size += TransientArraySize(*model, input, transient_data_alignment);
+ }
+ }
+ for (const auto& output : op->outputs) {
+ if (!array_lifespans[output].persistent) {
+ size += TransientArraySize(*model, output, transient_data_alignment);
+ }
+ }
+ // The optimal total size is the maximum of all operator-specific sizes.
+ optimal_transient_alloc_size = std::max(optimal_transient_alloc_size, size);
+ }
+
+ model->transient_data_size = allocator.total_size();
+ model->transient_data_alignment = transient_data_alignment;
+ CHECK_GE(model->transient_data_size, optimal_transient_alloc_size);
+ LOG(INFO) << "Total transient array allocated size: "
+ << model->transient_data_size << " bytes, "
+ << "theoretical optimal value: " << optimal_transient_alloc_size
+ << " bytes.";
+ CheckInvariants(*model);
+}
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.h b/tensorflow/contrib/lite/toco/allocate_transient_arrays.h
new file mode 100644
index 0000000000..12d0d0498f
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.h
@@ -0,0 +1,44 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_
+
+#include "tensorflow/contrib/lite/toco/model.h"
+
+namespace toco {
+
+// We align the allocated sizes to the next multiple of a cache line,
+// to get simple performance characteristics without side effects of
+// accesses to one buffer on accesses to another buffer.
+// That also takes care of data type alignment for any reasonable type
+// (no reasonable data type should have alignment greater than a cache line).
+// Here we make CPU-centric assumptions, in particular, we assume 64-byte cache
+// lines. Getting this wrong by a factor of 2x (if this ever changes) wouldn't
+// be terrible.
+// Embedded architectures may use a different value for alignment.
+constexpr std::size_t kDefaultTransientDataAlignment = 64;
+
+// Rounds up dividend to a value divisible by divisor.
+inline std::size_t RoundUpToNextMultipleOf(std::size_t dividend,
+ std::size_t divisor) {
+ return ((dividend + divisor - 1) / divisor) * divisor;
+}
+
+void AllocateTransientArrays(Model* model,
+ std::size_t transient_data_alignment);
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
new file mode 100644
index 0000000000..28661d4ff0
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -0,0 +1,225 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// This abstracts command line arguments in toco.
+// Arg<T> is a parseable type that can register a default value, be able to
+// parse itself, and keep track of whether it was specified.
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_
+
+#include <functional>
+#include <unordered_map>
+#include <vector>
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_split.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+
+namespace toco {
+
+// Since std::vector<int32> is in the std namespace, and we are not allowed
+// to add ParseFlag/UnparseFlag to std, we introduce a simple wrapper type
+// to use as the flag type:
+struct IntList {
+ std::vector<int32> elements;
+};
+struct StringMapList {
+ std::vector<std::unordered_map<string, string>> elements;
+};
+
+// command_line_flags.h don't track whether or not a flag is specified. Arg
+// contains the value (which will be default if not specified) and also
+// whether the flag is specified.
+// TODO(aselle): consider putting doc string and ability to construct the
+// tensorflow argument into this, so declaration of parameters can be less
+// distributed.
+// Every template specialization of Arg is required to implement
+// default_value(), specified(), value(), parse(), bind().
+template <class T>
+class Arg final {
+ public:
+ explicit Arg(T default_ = T()) : value_(default_) {}
+ virtual ~Arg() {}
+
+ // Provide default_value() to arg list
+ T default_value() const { return value_; }
+ // Return true if the command line argument was specified on the command line.
+ bool specified() const { return specified_; }
+ // Const reference to parsed value.
+ const T& value() const { return value_; }
+
+ // Parsing callback for the tensorflow::Flags code
+ bool parse(T value_in) {
+ value_ = value_in;
+ specified_ = true;
+ return true;
+ }
+
+ // Bind the parse member function so tensorflow::Flags can call it.
+ std::function<bool(T)> bind() {
+ return std::bind(&Arg::parse, this, std::placeholders::_1);
+ }
+
+ private:
+ // Becomes true after parsing if the value was specified
+ bool specified_ = false;
+ // Value of the argument (initialized to the default in the constructor).
+ T value_;
+};
+
+template <>
+class Arg<toco::IntList> final {
+ public:
+ // Provide default_value() to arg list
+ string default_value() const { return ""; }
+ // Return true if the command line argument was specified on the command line.
+ bool specified() const { return specified_; }
+ // Bind the parse member function so tensorflow::Flags can call it.
+ bool parse(string text) {
+ parsed_value_.elements.clear();
+ specified_ = true;
+ // strings::Split("") produces {""}, but we need {} on empty input.
+ // TODO(aselle): Moved this from elsewhere, but ahentz recommends we could
+ // use absl::SplitLeadingDec32Values(text.c_str(), &parsed_values_.elements)
+ if (!text.empty()) {
+ int32 element;
+ for (absl::string_view part : absl::StrSplit(text, ',')) {
+ if (!SimpleAtoi(part, &element)) return false;
+ parsed_value_.elements.push_back(element);
+ }
+ }
+ return true;
+ }
+
+ std::function<bool(string)> bind() {
+ return std::bind(&Arg::parse, this, std::placeholders::_1);
+ }
+
+ const toco::IntList& value() const { return parsed_value_; }
+
+ private:
+ toco::IntList parsed_value_;
+ bool specified_ = false;
+};
+
+template <>
+class Arg<toco::StringMapList> final {
+ public:
+ // Provide default_value() to StringMapList
+ string default_value() const { return ""; }
+ // Return true if the command line argument was specified on the command line.
+ bool specified() const { return specified_; }
+ // Bind the parse member function so tensorflow::Flags can call it.
+
+ bool parse(string text) {
+ parsed_value_.elements.clear();
+ specified_ = true;
+
+ if (text.empty()) {
+ return true;
+ }
+
+#if defined(PLATFORM_GOOGLE)
+ std::vector<absl::string_view> outer_vector;
+ absl::string_view text_disposable_copy = text;
+ SplitStructuredLine(text_disposable_copy, ',', "{}", &outer_vector);
+ for (const absl::string_view& outer_member_stringpiece : outer_vector) {
+ string outer_member(outer_member_stringpiece);
+ if (outer_member.empty()) {
+ continue;
+ }
+ string outer_member_copy = outer_member;
+ absl::StripAsciiWhitespace(&outer_member);
+ if (!TryStripPrefixString(outer_member, "{", &outer_member)) return false;
+ if (!TryStripSuffixString(outer_member, "}", &outer_member)) return false;
+ const std::vector<string> inner_fields_vector =
+ strings::Split(outer_member, ',');
+
+ std::unordered_map<string, string> element;
+ for (const string& member_field : inner_fields_vector) {
+ std::vector<string> outer_member_key_value =
+ strings::Split(member_field, ':');
+ if (outer_member_key_value.size() != 2) return false;
+ string& key = outer_member_key_value[0];
+ string& value = outer_member_key_value[1];
+ absl::StripAsciiWhitespace(&key);
+ absl::StripAsciiWhitespace(&value);
+ if (element.count(key) != 0) return false;
+ element[key] = value;
+ }
+ parsed_value_.elements.push_back(element);
+ }
+ return true;
+#else
+ // TODO(aselle): Fix argument parsing when absl supports structuredline
+ fprintf(stderr, "%s:%d StringMapList arguments not supported\n", __FILE__,
+ __LINE__);
+ abort();
+#endif
+ }
+
+ std::function<bool(string)> bind() {
+ return std::bind(&Arg::parse, this, std::placeholders::_1);
+ }
+
+ const toco::StringMapList& value() const { return parsed_value_; }
+
+ private:
+ toco::StringMapList parsed_value_;
+ bool specified_ = false;
+};
+
+// Flags that describe a model. See model_cmdline_flags.cc for details.
+struct ParsedModelFlags {
+ Arg<string> input_array;
+ Arg<string> input_arrays;
+ Arg<string> output_array;
+ Arg<string> output_arrays;
+ Arg<string> input_shapes;
+ Arg<float> mean_value = Arg<float>(0.f);
+ Arg<string> mean_values;
+ Arg<float> std_value = Arg<float>(1.f);
+ Arg<string> std_values;
+ Arg<bool> variable_batch = Arg<bool>(false);
+ Arg<bool> drop_control_dependency = Arg<bool>(false);
+ Arg<toco::IntList> input_shape;
+ Arg<toco::StringMapList> rnn_states;
+ Arg<toco::StringMapList> model_checks;
+ // Debugging output options
+ Arg<string> graphviz_first_array;
+ Arg<string> graphviz_last_array;
+ Arg<string> dump_graphviz;
+ Arg<bool> dump_graphviz_video = Arg<bool>(false);
+};
+
+// Flags that describe the operation you would like to do (what conversion
+// you want). See toco_cmdline_flags.cc for details.
+struct ParsedTocoFlags {
+ Arg<string> input_file;
+ Arg<string> output_file;
+ Arg<string> input_format;
+ Arg<string> output_format;
+ // TODO(aselle): command_line_flags doesn't support doubles
+ Arg<float> default_ranges_min = Arg<float>(0.);
+ Arg<float> default_ranges_max = Arg<float>(0.);
+ Arg<string> input_type;
+ Arg<string> input_types;
+ Arg<string> inference_type;
+ Arg<bool> drop_fake_quant = Arg<bool>(false);
+ Arg<bool> reorder_across_fake_quant = Arg<bool>(false);
+ Arg<bool> allow_custom_ops = Arg<bool>(false);
+};
+
+} // namespace toco
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_
diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc
new file mode 100644
index 0000000000..f5e2868dc0
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc
@@ -0,0 +1,293 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/dump_graphviz.h"
+
+#include <memory>
+#include <set>
+#include <unordered_set>
+#include <vector>
+
+#include "absl/strings/str_replace.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+using toco::port::AppendF;
+using toco::port::StringF;
+
+namespace toco {
+namespace {
+
+class Color {
+ public:
+ Color() {}
+ Color(uint8 r, uint8 g, uint8 b) : r_(r), g_(g), b_(b) {}
+ // Returns the string serialization of this color in graphviz format,
+ // for use as 'fillcolor' in boxes.
+ string FillColorString() const { return StringF("%.2X%.2X%.2X", r_, g_, b_); }
+ // Returns the serialization in graphviz format of a suitable color to use
+ // 'fontcolor' in the same boxes. It should black or white, whichever offers
+ // the better contrast from FillColorString().
+ string TextColorString() const {
+ // https://en.wikipedia.org/wiki/Relative_luminance
+ const float luminance = 0.2126f * r_ + 0.7152f * g_ + 0.0722f * b_;
+ const uint8 l = luminance > 128.f ? 0 : 255;
+ return StringF("%.2X%.2X%.2X", l, l, l);
+ }
+
+ private:
+ uint8 r_ = 0, g_ = 0, b_ = 0;
+};
+
+struct NodeProperties {
+ // The text to display inside the box for this node.
+ string label;
+ // The color to use for this node; will be used as 'fillcolor'
+ // for its box. See Color::FillColorString. A suitable, different
+ // color will be chosen for the 'fontcolor' for the inside text
+ // label, see Color::TextColorString.
+ Color color;
+};
+
+// All colors in this file are from:
+// https://material.io/guidelines/style/color.html
+
+Color GetColorForArray(const Model& model, const string& array_name) {
+ // Arrays involved in RNN back-edges have a different color
+ for (const auto& rnn_state : model.flags.rnn_states()) {
+ // RNN state, fed by a back-edge. Bold color.
+ if (array_name == rnn_state.state_array()) {
+ return Color(0x0F, 0x9D, 0x58);
+ }
+ // RNN back-edge source, feeding a RNN state.
+ // Light tone of the same color as RNN states.
+ if (array_name == rnn_state.back_edge_source_array()) {
+ return Color(0xB7, 0xE1, 0xCD);
+ }
+ }
+ // Constant parameter arrays have their own bold color
+ if (model.GetArray(array_name).buffer) {
+ return Color(0x42, 0x85, 0xF4);
+ }
+ // Remaining arrays are activations.
+ // We use gray colors for them because they are the majority
+ // of arrays so we want to highlight other arrays instead of them.
+ // First, we use a bolder gray for input/output arrays:
+ const auto& dump_options = *GraphVizDumpOptions::singleton();
+ if (IsInputArray(model, array_name) ||
+ array_name == dump_options.graphviz_first_array ||
+ array_name == dump_options.graphviz_last_array) {
+ return Color(0x9E, 0x9E, 0x9E);
+ }
+ for (const string& output_array : model.flags.output_arrays()) {
+ if (array_name == output_array) {
+ return Color(0x9E, 0x9E, 0x9E);
+ }
+ }
+ // Remaining arrays are intermediate activation arrays.
+ // Lighter tone of the same grey as for input/output arrays:
+ // We want these to be very discrete.
+ return Color(0xF5, 0xF5, 0xF5);
+}
+
+NodeProperties GetPropertiesForArray(const Model& model,
+ const string& array_name) {
+ NodeProperties node_properties;
+ node_properties.color = GetColorForArray(model, array_name);
+ node_properties.label = absl::StrReplaceAll(array_name, {{"/", "/\\n"}});
+
+ // Append array shape to the label.
+ auto& array = model.GetArray(array_name);
+
+ if (array.data_type == ArrayDataType::kFloat) {
+ AppendF(&node_properties.label, "\\nType: float");
+ } else if (array.data_type == ArrayDataType::kInt32) {
+ AppendF(&node_properties.label, "\\nType: int32");
+ } else if (array.data_type == ArrayDataType::kUint8) {
+ AppendF(&node_properties.label, "\\nType: uint8");
+ }
+
+ if (array.has_shape()) {
+ auto& array_shape = array.shape();
+ node_properties.label += "\\n[";
+ for (int id = 0; id < array_shape.dimensions_count(); id++) {
+ if (id == 0) {
+ AppendF(&node_properties.label, "%d", array_shape.dims(id));
+ } else {
+ AppendF(&node_properties.label, "x%d", array_shape.dims(id));
+ }
+ }
+ node_properties.label += "]";
+ }
+
+ if (array.minmax) {
+ AppendF(&node_properties.label, "\\nMinMax: [%.3g, %.3g]",
+ array.minmax->min, array.minmax->max);
+ }
+
+ if (array.quantization_params) {
+ AppendF(&node_properties.label, "\\nQuantization: %.3g * (x - %d)",
+ array.quantization_params->scale,
+ array.quantization_params->zero_point);
+ }
+
+ if (array.alloc) {
+ AppendF(&node_properties.label, "\\nTransient Alloc: [%d, %d)",
+ array.alloc->start, array.alloc->end);
+ }
+
+ return node_properties;
+}
+
+NodeProperties GetPropertiesForOperator(const Operator& op) {
+ NodeProperties node_properties;
+ if (op.type == OperatorType::kTensorFlowUnsupported) {
+ node_properties.label =
+ static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op;
+ } else {
+ node_properties.label = OperatorTypeName(op.type);
+ }
+ // Additional information for some of the operators.
+ switch (op.type) {
+ case OperatorType::kConv: {
+ const auto& conv_op = static_cast<const ConvOperator&>(op);
+ node_properties.color = Color(0xC5, 0x39, 0x29); // Bolder color
+ AppendF(&node_properties.label, "\\n%dx%d/%s", conv_op.stride_width,
+ conv_op.stride_height,
+ conv_op.padding.type == PaddingType::kSame ? "S" : "V");
+ break;
+ }
+ case OperatorType::kDepthwiseConv: {
+ const auto& conv_op = static_cast<const DepthwiseConvOperator&>(op);
+ node_properties.color = Color(0xC5, 0x39, 0x29); // Bolder color
+ AppendF(&node_properties.label, "\\n%dx%d/%s", conv_op.stride_width,
+ conv_op.stride_height,
+ conv_op.padding.type == PaddingType::kSame ? "S" : "V");
+ break;
+ }
+ case OperatorType::kFullyConnected: {
+ node_properties.color = Color(0xC5, 0x39, 0x29); // Bolder color
+ break;
+ }
+ default:
+ node_properties.color = Color(0xDB, 0x44, 0x37);
+ break;
+ }
+
+ return node_properties;
+}
+
+std::vector<const Operator*> OperatorsToDump(const Model& model) {
+ const auto& dump_options = *GraphVizDumpOptions::singleton();
+ bool first_specified = !dump_options.graphviz_first_array.empty();
+ bool last_specified = !dump_options.graphviz_last_array.empty();
+ CHECK_EQ(first_specified, last_specified);
+ std::vector<const Operator*> ops_to_dump;
+ if (last_specified) {
+ // Return only the part of the graph between graphviz_first_array
+ // and graphviz_last_array.
+ CHECK(model.arrays.count(dump_options.graphviz_first_array));
+ CHECK(model.arrays.count(dump_options.graphviz_last_array));
+ std::unordered_set<string> arrays_already_produced;
+ std::vector<string> arrays_to_produce;
+ arrays_to_produce.push_back(dump_options.graphviz_last_array);
+ while (!arrays_to_produce.empty()) {
+ const string array = arrays_to_produce.back();
+ arrays_to_produce.pop_back();
+ CHECK(!arrays_already_produced.count(array));
+ arrays_already_produced.insert(array);
+ const Operator* op = GetOpWithOutput(model, array);
+ if (!op) {
+ continue;
+ }
+ ops_to_dump.push_back(op);
+ for (const string& input : op->inputs) {
+ if (arrays_already_produced.count(input) ||
+ input == dump_options.graphviz_first_array) {
+ continue;
+ }
+ arrays_to_produce.push_back(input);
+ }
+ }
+ } else {
+ // Return the whole graph.
+ for (const auto& op : model.operators) {
+ ops_to_dump.push_back(op.get());
+ }
+ }
+ return ops_to_dump;
+}
+
+} // namespace
+
+void DumpGraphviz(const Model& model, string* output_file_contents) {
+ AppendF(output_file_contents, "digraph Computegraph {\n");
+
+ constexpr char kNodeFormat[] =
+ "\t \"%s\" [label=\"%s\", shape=%s, style=filled, fillcolor=\"#%s\", "
+ "fontcolor = \"#%sDD\"];\n";
+
+ constexpr char kEdgeFormat[] = "\t \"%s\" -> \"%s\";\n";
+
+ constexpr char kRNNBackEdgeFormat[] =
+ "\t \"%s\" -> \"%s\" [color=\"#0F9D58\"];\n";
+
+ std::vector<const Operator*> ops_to_dump = OperatorsToDump(model);
+ std::set<string> already_added_arrays;
+ for (int op_index = 0; op_index < ops_to_dump.size(); op_index++) {
+ const Operator& op = *ops_to_dump[op_index];
+ // Add node for operator.
+ auto op_properties = GetPropertiesForOperator(op);
+ string operator_id = StringF("op%05d", op_index);
+ AppendF(output_file_contents, kNodeFormat, operator_id, op_properties.label,
+ "box", op_properties.color.FillColorString().c_str(),
+ op_properties.color.TextColorString().c_str());
+ // Add nodes and edges for all inputs of the operator.
+ for (const auto& input : op.inputs) {
+ auto array_properties = GetPropertiesForArray(model, input);
+ if (!already_added_arrays.count(input)) {
+ AppendF(output_file_contents, kNodeFormat, input,
+ array_properties.label, "octagon",
+ array_properties.color.FillColorString().c_str(),
+ array_properties.color.TextColorString().c_str());
+ }
+ AppendF(output_file_contents, kEdgeFormat, input, operator_id);
+ already_added_arrays.insert(input);
+ }
+ // Add nodes and edges for all outputs of the operator.
+ for (const auto& output : op.outputs) {
+ auto array_properties = GetPropertiesForArray(model, output);
+ if (!already_added_arrays.count(output)) {
+ AppendF(output_file_contents, kNodeFormat, output,
+ array_properties.label, "octagon",
+ array_properties.color.FillColorString().c_str(),
+ array_properties.color.TextColorString().c_str());
+ }
+ AppendF(output_file_contents, kEdgeFormat, operator_id, output);
+ already_added_arrays.insert(output);
+ }
+ }
+
+ for (const auto& rnn_state : model.flags.rnn_states()) {
+ AppendF(output_file_contents, kRNNBackEdgeFormat,
+ rnn_state.back_edge_source_array(), rnn_state.state_array());
+ }
+
+ AppendF(output_file_contents, "}\n");
+}
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.h b/tensorflow/contrib/lite/toco/dump_graphviz.h
new file mode 100644
index 0000000000..0fb28e3de8
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/dump_graphviz.h
@@ -0,0 +1,28 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_
+
+#include <string>
+
+#include "tensorflow/contrib/lite/toco/model.h"
+
+namespace toco {
+
+void DumpGraphviz(const Model& model, string* output_file_contents);
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
new file mode 100644
index 0000000000..16b9fa2260
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -0,0 +1,1570 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "google/protobuf/map.h"
+#include "google/protobuf/text_format.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_util.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/logging.h"
+
+using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT32;
+using tensorflow::GraphDef;
+using tensorflow::TensorProto;
+
+namespace toco {
+namespace {
+
+// TensorFlow sometimes forbids what it calls "legacy scalars",
+// which are 1-D shapes where the unique shape size is 1.
+// See OpKernel::IsLegacyScalar and OpKernel::allow_legacy_scalars.
+// For that reason, we generally avoid creating legacy scalars,
+// by detecting the case where a 1-D shape would be of size 1 and
+// replacing that by a 0-D shape.
+// However, there is a special circumstance where we must not do that
+// and must unconditionally create a 1-D shape even if it is going to
+// be of size 1: that is the case of bias vectors, with BiasAdd nodes.
+// Indeed, TensorFlow requires bias vectors to be 1-D; in the case of
+// a depth of 1, that would be a legacy scalar, so in that case we
+// must go ahead and keep the shape 1-D, letting it be a legacy scalar.
+enum class LegacyScalarPolicy { kAvoidLegacyScalars, kDoCreateLegacyScalars };
+
+void ExportFloatArray(const Shape& input_shape, const float* input_data,
+ TensorProto* output_tensor,
+ LegacyScalarPolicy legacy_scalar_policy) {
+ output_tensor->set_dtype(DT_FLOAT);
+ const int input_flat_size = RequiredBufferSizeForShape(input_shape);
+ auto* shape = output_tensor->mutable_tensor_shape();
+
+ const int kDims = input_shape.dimensions_count();
+ if (legacy_scalar_policy == LegacyScalarPolicy::kDoCreateLegacyScalars ||
+ kDims > 1 || (kDims == 1 && input_shape.dims(0) > 1)) {
+ for (int i = 0; i < kDims; ++i) {
+ shape->add_dim()->set_size(input_shape.dims(i));
+ }
+ }
+ output_tensor->set_tensor_content(
+ string(reinterpret_cast<const char*>(input_data),
+ sizeof(*input_data) * input_flat_size));
+}
+
+void ExportFloatArray(AxesOrder input_axes_order, const Shape& input_shape,
+ const float* input_data, AxesOrder output_axes_order,
+ TensorProto* output_tensor,
+ LegacyScalarPolicy legacy_scalar_policy) {
+ CHECK_EQ(AxesCount(output_axes_order), AxesCount(input_axes_order));
+ output_tensor->set_dtype(DT_FLOAT);
+ CHECK_EQ(input_shape.dimensions_count(), AxesCount(input_axes_order));
+ const int input_flat_size = RequiredBufferSizeForShape(input_shape);
+
+ Shape shuffled_shape;
+ ShuffleDims(input_shape, input_axes_order, output_axes_order,
+ &shuffled_shape);
+ std::vector<float> shuffled_data(input_flat_size);
+ ShuffleArray(input_shape, input_axes_order, output_axes_order, shuffled_shape,
+ input_data, shuffled_data.data());
+
+ ExportFloatArray(shuffled_shape, shuffled_data.data(), output_tensor,
+ legacy_scalar_policy);
+}
+
+bool HasAlreadyExportedConst(const string& name,
+ const GraphDef& tensorflow_graph) {
+ for (const auto& node : tensorflow_graph.node()) {
+ if (node.op() == "Const" && node.name() == name) {
+ return true;
+ }
+ }
+ return false;
+}
+
+void ConvertFloatTensorConst(const string& name, const Shape& input_shape,
+ const float* input_data,
+ AxesOrder input_axes_order,
+ AxesOrder output_axes_order,
+ GraphDef* tensorflow_graph,
+ LegacyScalarPolicy legacy_scalar_policy) {
+ if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
+ return;
+ }
+ auto* const_op = tensorflow_graph->add_node();
+ const_op->set_op("Const");
+ const_op->set_name(name);
+ (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
+ auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
+ ExportFloatArray(input_axes_order, input_shape, input_data, output_axes_order,
+ tensor, legacy_scalar_policy);
+}
+
+void ConvertFloatTensorConst(const string& name, const Shape& input_shape,
+ const float* input_data,
+ AxesOrder input_axes_order,
+ AxesOrder output_axes_order,
+ GraphDef* tensorflow_graph) {
+ if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
+ return;
+ }
+ auto* const_op = tensorflow_graph->add_node();
+ const_op->set_op("Const");
+ const_op->set_name(name);
+ (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
+ auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
+ ExportFloatArray(input_axes_order, input_shape, input_data, output_axes_order,
+ tensor, LegacyScalarPolicy::kAvoidLegacyScalars);
+}
+
+void ConvertFloatTensorConst(const Model& model, const string& name,
+ AxesOrder input_axes_order,
+ AxesOrder output_axes_order,
+ GraphDef* tensorflow_graph) {
+ if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
+ return;
+ }
+ auto* const_op = tensorflow_graph->add_node();
+ const_op->set_op("Const");
+ const_op->set_name(name);
+ (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
+ auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
+ CHECK(model.arrays.count(name));
+ const auto& input_array = *model.arrays.at(name);
+ const auto& input_shape = input_array.shape();
+ CHECK(input_array.buffer);
+ CHECK(input_array.buffer->type == ArrayDataType::kFloat);
+ const float* input_data =
+ input_array.GetBuffer<ArrayDataType::kFloat>().data.data();
+ ExportFloatArray(input_axes_order, input_shape, input_data, output_axes_order,
+ tensor, LegacyScalarPolicy::kAvoidLegacyScalars);
+}
+
+void ConvertFloatTensorConst(const Model& model, const string& name,
+ GraphDef* tensorflow_graph) {
+ if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
+ return;
+ }
+ auto* const_op = tensorflow_graph->add_node();
+ const_op->set_op("Const");
+ const_op->set_name(name);
+ (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
+ auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
+ CHECK(model.arrays.count(name));
+ const auto& input_array = *model.arrays.at(name);
+ const auto& input_shape = input_array.shape();
+ CHECK(input_array.buffer);
+ CHECK(input_array.buffer->type == ArrayDataType::kFloat);
+ const float* input_data =
+ input_array.GetBuffer<ArrayDataType::kFloat>().data.data();
+ ExportFloatArray(input_shape, input_data, tensor,
+ LegacyScalarPolicy::kAvoidLegacyScalars);
+}
+
+void ConvertIntTensorConst(const Model& model, const string& name,
+ GraphDef* tensorflow_graph) {
+ if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
+ return;
+ }
+ CHECK(model.arrays.count(name));
+ const auto& array = *model.arrays.at(name);
+ auto* const_op = tensorflow_graph->add_node();
+ const_op->set_op("Const");
+ const_op->set_name(name);
+ (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
+ auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
+ tensor->set_dtype(DT_INT32);
+ const auto& data = array.GetBuffer<ArrayDataType::kInt32>().data;
+ for (auto index : data) {
+ tensor->add_int_val(index);
+ }
+ const auto& array_shape = array.shape();
+ auto* shape = tensor->mutable_tensor_shape();
+ for (int i = 0; i < array_shape.dimensions_count(); i++) {
+ shape->add_dim()->set_size(array_shape.dims(i));
+ }
+}
+
+void CreateMatrixShapeTensorConst(const string& name, int rows, int cols,
+ GraphDef* tensorflow_graph) {
+ if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
+ return;
+ }
+ auto* const_op = tensorflow_graph->add_node();
+ const_op->set_op("Const");
+ const_op->set_name(name);
+ (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
+ auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
+ tensor->set_dtype(DT_INT32);
+ const int32 data[2] = {cols, rows};
+ tensor->set_tensor_content(
+ string(reinterpret_cast<const char*>(data), sizeof(data)));
+ auto* shape = tensor->mutable_tensor_shape();
+ shape->add_dim()->set_size(2);
+}
+
+void CreateDummyConcatDimTensorConst(const string& name, int dim,
+ GraphDef* tensorflow_graph) {
+ if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
+ return;
+ }
+ auto* const_op = tensorflow_graph->add_node();
+ const_op->set_op("Const");
+ const_op->set_name(name);
+ (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
+ auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
+ tensor->set_dtype(DT_INT32);
+ tensor->add_int_val(dim);
+}
+
+void CreateReshapeShapeTensorConst(const string& name,
+ const std::vector<int32>& shape,
+ GraphDef* tensorflow_graph) {
+ if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
+ return;
+ }
+ auto* const_op = tensorflow_graph->add_node();
+ const_op->set_op("Const");
+ const_op->set_name(name);
+ (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
+ auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
+ tensor->set_dtype(DT_INT32);
+ for (auto s : shape) {
+ tensor->add_int_val(s);
+ }
+ // TensorFlow sometimes forbids what it calls "legacy scalars",
+ // which are shapes of size 1 where the unique shape size is 1.
+ // See OpKernel::IsLegacyScalar and OpKernel::allow_legacy_scalars.
+ if (shape.size() > 1) {
+ auto* tensor_shape = tensor->mutable_tensor_shape();
+ tensor_shape->add_dim()->set_size(shape.size());
+ }
+}
+
+string WalkUpToConstantArray(const Model& model, const string& name) {
+ const Array& original_array = model.GetArray(name);
+ if (original_array.buffer) {
+ return name;
+ }
+ const auto* op = GetOpWithOutput(model, name);
+ CHECK(op);
+ CHECK(op->type == OperatorType::kFakeQuant);
+ const string& input_of_fakequant_name = op->inputs[0];
+ const Array& input_of_fakequant = model.GetArray(input_of_fakequant_name);
+ CHECK(input_of_fakequant.buffer);
+ return input_of_fakequant_name;
+}
+
+void ConvertConvOperator(const Model& model, const ConvOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ const bool has_bias = src_op.inputs.size() >= 3;
+ string conv_output = src_op.outputs[0];
+ if (has_bias) {
+ conv_output += "/conv";
+ }
+
+ auto* conv2d_op = tensorflow_graph->add_node();
+ conv2d_op->set_op("Conv2D");
+ conv2d_op->set_name(conv_output);
+ *conv2d_op->add_input() = src_op.inputs[0];
+ *conv2d_op->add_input() = src_op.inputs[1];
+ (*conv2d_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ const string& weights_array_name =
+ WalkUpToConstantArray(model, src_op.inputs[1]);
+ const auto& weights_array = model.GetArray(weights_array_name);
+ CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
+ ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI,
+ AxesOrder::kHWIO, tensorflow_graph);
+ auto& strides = (*conv2d_op->mutable_attr())["strides"];
+ strides.mutable_list()->add_i(1);
+ strides.mutable_list()->add_i(src_op.stride_height);
+ strides.mutable_list()->add_i(src_op.stride_width);
+ strides.mutable_list()->add_i(1);
+ string padding;
+ if (src_op.padding.type == PaddingType::kSame) {
+ padding = "SAME";
+ } else if (src_op.padding.type == PaddingType::kValid) {
+ padding = "VALID";
+ } else {
+ LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
+ }
+ (*conv2d_op->mutable_attr())["padding"].set_s(padding);
+
+ if (has_bias) {
+ auto* biasadd_op = tensorflow_graph->add_node();
+ biasadd_op->set_op("BiasAdd");
+ biasadd_op->set_name(src_op.outputs[0]);
+ biasadd_op->add_input(conv_output);
+ biasadd_op->add_input(src_op.inputs[2]);
+ (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ CHECK(model.arrays.count(src_op.inputs[2]));
+ const string& bias_array_name =
+ WalkUpToConstantArray(model, src_op.inputs[2]);
+ const auto& bias_array = model.GetArray(bias_array_name);
+ // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
+ Shape bias_shape_1d = bias_array.shape();
+ UnextendShape(&bias_shape_1d, 1);
+ CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
+ const float* bias_data =
+ bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
+ ConvertFloatTensorConst(bias_array_name, bias_shape_1d, bias_data,
+ AxesOrder::kOneAxis, AxesOrder::kOneAxis,
+ tensorflow_graph,
+ LegacyScalarPolicy::kDoCreateLegacyScalars);
+ }
+}
+
+void ConvertDepthwiseConvOperator(const Model& model,
+ const DepthwiseConvOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ const bool has_bias = src_op.inputs.size() >= 3;
+ string conv_output = src_op.outputs[0];
+ if (has_bias) {
+ conv_output += "/conv";
+ }
+
+ auto* dc2d_op = tensorflow_graph->add_node();
+ dc2d_op->set_op("DepthwiseConv2dNative");
+ dc2d_op->set_name(conv_output);
+ *dc2d_op->add_input() = src_op.inputs[0];
+ *dc2d_op->add_input() = src_op.inputs[1];
+ (*dc2d_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ // Our internal DepthwiseConv weights are 1 x H x W x OutputDepth.
+ // We need to convert that to H x W x InputDepth x Multiplier.
+ // That's only a matter of constructing a Dims object; the actual
+ // array layout is the same.
+ CHECK(model.arrays.count(src_op.inputs[1]));
+ const string& src_weights_name =
+ WalkUpToConstantArray(model, src_op.inputs[1]);
+ const auto& src_weights_array = model.GetArray(src_weights_name);
+ const auto& src_weights_shape = src_weights_array.shape();
+ CHECK_EQ(src_weights_shape.dimensions_count(), 4);
+ const Shape dst_weights_shape =
+ Shape({src_weights_shape.dims(1), src_weights_shape.dims(2),
+ src_weights_shape.dims(3) / src_op.depth_multiplier,
+ src_op.depth_multiplier});
+ CHECK_EQ(src_weights_shape.dims(3) % src_op.depth_multiplier, 0);
+ CHECK(dst_weights_shape.dims(2) * dst_weights_shape.dims(3) ==
+ src_weights_shape.dims(3));
+ CHECK_EQ(src_weights_shape.dims(0), 1);
+
+ CHECK(src_weights_array.buffer->type == ArrayDataType::kFloat);
+ const float* src_weights_data =
+ src_weights_array.GetBuffer<ArrayDataType::kFloat>().data.data();
+ ConvertFloatTensorConst(src_weights_name, dst_weights_shape, src_weights_data,
+ AxesOrder::kHWIM, AxesOrder::kHWIM, tensorflow_graph);
+
+ auto& strides = (*dc2d_op->mutable_attr())["strides"];
+ strides.mutable_list()->add_i(1);
+ strides.mutable_list()->add_i(src_op.stride_height);
+ strides.mutable_list()->add_i(src_op.stride_width);
+ strides.mutable_list()->add_i(1);
+ string padding;
+ if (src_op.padding.type == PaddingType::kSame) {
+ padding = "SAME";
+ } else if (src_op.padding.type == PaddingType::kValid) {
+ padding = "VALID";
+ } else {
+ LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
+ }
+ (*dc2d_op->mutable_attr())["padding"].set_s(padding);
+
+ if (has_bias) {
+ auto* biasadd_op = tensorflow_graph->add_node();
+ biasadd_op->set_op("BiasAdd");
+ biasadd_op->set_name(src_op.outputs[0]);
+ biasadd_op->add_input(conv_output);
+ biasadd_op->add_input(src_op.inputs[2]);
+ (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ CHECK(model.arrays.count(src_op.inputs[2]));
+ const string& bias_name = WalkUpToConstantArray(model, src_op.inputs[2]);
+ const auto& bias_array = model.GetArray(bias_name);
+ // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
+ Shape bias_shape_1d = bias_array.shape();
+ UnextendShape(&bias_shape_1d, 1);
+ CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
+ const float* bias_data =
+ bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
+ ConvertFloatTensorConst(bias_name, bias_shape_1d, bias_data,
+ AxesOrder::kOneAxis, AxesOrder::kOneAxis,
+ tensorflow_graph,
+ LegacyScalarPolicy::kDoCreateLegacyScalars);
+ }
+}
+
+void ConvertDepthToSpaceOperator(const Model& model,
+ const DepthToSpaceOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* op = tensorflow_graph->add_node();
+ op->set_op("DepthToSpace");
+ op->set_name(src_op.outputs[0]);
+ *op->add_input() = src_op.inputs[0];
+ (*op->mutable_attr())["T"].set_type(DT_FLOAT);
+ (*op->mutable_attr())["block_size"].set_i(src_op.block_size);
+}
+
+void ConvertSpaceToDepthOperator(const Model& model,
+ const SpaceToDepthOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* op = tensorflow_graph->add_node();
+ op->set_op("SpaceToDepth");
+ op->set_name(src_op.outputs[0]);
+ *op->add_input() = src_op.inputs[0];
+ (*op->mutable_attr())["T"].set_type(DT_FLOAT);
+ (*op->mutable_attr())["block_size"].set_i(src_op.block_size);
+}
+
+void ConvertFullyConnectedOperator(const Model& model,
+ const FullyConnectedOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ const string reshape_output = src_op.outputs[0] + "/reshape";
+ const string reshape_shape = src_op.outputs[0] + "/reshape/shape";
+ auto* reshape_op = tensorflow_graph->add_node();
+ reshape_op->set_op("Reshape");
+ reshape_op->set_name(reshape_output);
+ reshape_op->add_input(src_op.inputs[0]);
+ reshape_op->add_input(reshape_shape);
+ (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ const bool has_bias = src_op.inputs.size() >= 3;
+ string matmul_output = src_op.outputs[0];
+ if (has_bias) {
+ matmul_output += "/matmul";
+ }
+
+ auto* matmul_op = tensorflow_graph->add_node();
+ matmul_op->set_op("MatMul");
+
+ matmul_op->set_name(matmul_output);
+ *matmul_op->add_input() = reshape_output;
+ *matmul_op->add_input() = src_op.inputs[1];
+ (*matmul_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ (*matmul_op->mutable_attr())["transpose_a"].set_b(false);
+ (*matmul_op->mutable_attr())["transpose_b"].set_b(false);
+ CHECK(model.arrays.count(src_op.inputs[1]));
+ const string& fc_weights_name =
+ WalkUpToConstantArray(model, src_op.inputs[1]);
+ const auto& fc_weights_array = *model.arrays.at(fc_weights_name);
+ const auto& fc_weights_shape = fc_weights_array.shape();
+ CHECK_EQ(fc_weights_shape.dimensions_count(), 2);
+ CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1,
+ tensorflow_graph);
+
+ CHECK(fc_weights_array.buffer);
+ CHECK(fc_weights_array.buffer->type == ArrayDataType::kFloat);
+ const float* fc_weights_data =
+ fc_weights_array.GetBuffer<ArrayDataType::kFloat>().data.data();
+ ConvertFloatTensorConst(fc_weights_name, fc_weights_shape, fc_weights_data,
+ AxesOrder::kCR, AxesOrder::kRC, tensorflow_graph);
+
+ if (has_bias) {
+ auto* biasadd_op = tensorflow_graph->add_node();
+ biasadd_op->set_op("BiasAdd");
+ biasadd_op->set_name(src_op.outputs[0]);
+ biasadd_op->add_input(matmul_output);
+ biasadd_op->add_input(src_op.inputs[2]);
+ (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ CHECK(model.arrays.count(src_op.inputs[2]));
+ const auto& bias_array = *model.arrays.at(src_op.inputs[2]);
+ // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
+ Shape bias_shape_1d = bias_array.shape();
+ UnextendShape(&bias_shape_1d, 1);
+ CHECK(bias_array.buffer);
+ CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
+ const float* bias_data =
+ bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
+ ConvertFloatTensorConst(WalkUpToConstantArray(model, src_op.inputs[2]),
+ bias_shape_1d, bias_data, AxesOrder::kOneAxis,
+ AxesOrder::kOneAxis, tensorflow_graph,
+ LegacyScalarPolicy::kDoCreateLegacyScalars);
+ }
+}
+
+void ConvertAddOperator(const Model& model, const AddOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* add_op = tensorflow_graph->add_node();
+ add_op->set_op("Add");
+ add_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *add_op->add_input() = src_op.inputs[0];
+ *add_op->add_input() = src_op.inputs[1];
+ (*add_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertMulOperator(const Model& model, const MulOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* add_op = tensorflow_graph->add_node();
+ add_op->set_op("Mul");
+ add_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *add_op->add_input() = src_op.inputs[0];
+ *add_op->add_input() = src_op.inputs[1];
+ (*add_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertReluOperator(const ReluOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* relu_op = tensorflow_graph->add_node();
+ relu_op->set_op("Relu");
+ relu_op->set_name(src_op.outputs[0]);
+ *relu_op->add_input() = src_op.inputs[0];
+ (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertRelu1Operator(const Relu1Operator& src_op,
+ GraphDef* tensorflow_graph) {
+ const string max_bounds = src_op.outputs[0] + "/max_bounds";
+ const string min_bounds = src_op.outputs[0] + "/min_bounds";
+ const string max_output = src_op.outputs[0] + "/max_output";
+
+ auto* max_bounds_const_op = tensorflow_graph->add_node();
+ max_bounds_const_op->set_op("Const");
+ max_bounds_const_op->set_name(max_bounds);
+ (*max_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
+ auto* max_bounds_const_op_tensor =
+ (*max_bounds_const_op->mutable_attr())["value"].mutable_tensor();
+ max_bounds_const_op_tensor->set_dtype(DT_FLOAT);
+ max_bounds_const_op_tensor->add_float_val(-1.0f);
+
+ auto* min_bounds_const_op = tensorflow_graph->add_node();
+ min_bounds_const_op->set_op("Const");
+ min_bounds_const_op->set_name(min_bounds);
+ (*min_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
+ auto* min_bounds_const_op_tensor =
+ (*min_bounds_const_op->mutable_attr())["value"].mutable_tensor();
+ min_bounds_const_op_tensor->set_dtype(DT_FLOAT);
+ min_bounds_const_op_tensor->add_float_val(1.0f);
+
+ auto* max_op = tensorflow_graph->add_node();
+ max_op->set_op("Maximum");
+ max_op->set_name(max_output);
+ *max_op->add_input() = src_op.inputs[0];
+ *max_op->add_input() = max_bounds;
+ (*max_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ auto* min_op = tensorflow_graph->add_node();
+ min_op->set_op("Minimum");
+ min_op->set_name(src_op.outputs[0]);
+ *min_op->add_input() = max_output;
+ *min_op->add_input() = min_bounds;
+ (*min_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertRelu6Operator(const Relu6Operator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* relu_op = tensorflow_graph->add_node();
+ relu_op->set_op("Relu6");
+ relu_op->set_name(src_op.outputs[0]);
+ *relu_op->add_input() = src_op.inputs[0];
+ (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertLogisticOperator(const LogisticOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* relu_op = tensorflow_graph->add_node();
+ relu_op->set_op("Sigmoid");
+ relu_op->set_name(src_op.outputs[0]);
+ *relu_op->add_input() = src_op.inputs[0];
+ (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertTanhOperator(const TanhOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* tanh_op = tensorflow_graph->add_node();
+ tanh_op->set_op("Tanh");
+ tanh_op->set_name(src_op.outputs[0]);
+ *tanh_op->add_input() = src_op.inputs[0];
+ (*tanh_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ string softmax_input;
+ Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]);
+ if (providing_op->type == OperatorType::kTensorFlowReshape) {
+ softmax_input = src_op.inputs[0];
+ } else {
+ // Insert a reshape operator that reduces the dimensions down to the 2 that
+ // are required for TensorFlow Logits.
+ const string reshape_output = src_op.outputs[0] + "/softmax_insert_reshape";
+ const string softmax_size = src_op.outputs[0] + "/softmax_insert_size";
+ softmax_input = reshape_output;
+
+ auto* reshape_op = tensorflow_graph->add_node();
+ reshape_op->set_op("Reshape");
+ reshape_op->set_name(reshape_output);
+ *reshape_op->add_input() = src_op.inputs[0];
+ *reshape_op->add_input() = softmax_size;
+ (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ const auto& input_shape = model.arrays.at(src_op.inputs[0])->shape();
+ int32 flattened_size = 1;
+ for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) {
+ flattened_size *= input_shape.dims(i);
+ }
+ const std::vector<int32> shape_data = {
+ flattened_size, input_shape.dims(input_shape.dimensions_count() - 1)};
+ CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph);
+ }
+
+ auto* softmax_op = tensorflow_graph->add_node();
+ softmax_op->set_op("Softmax");
+ softmax_op->set_name(src_op.outputs[0]);
+ *softmax_op->add_input() = softmax_input;
+ // TensorFlow's Softmax doesn't seem to admit a 'beta' parameter
+ CHECK_EQ(src_op.beta, 1.f);
+ (*softmax_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ const string square_output = src_op.outputs[0] + "/square";
+ const string sum_reduction_indices = src_op.outputs[0] + "/reduction_indices";
+ const string sum_output = src_op.outputs[0] + "/sum";
+ const string rsqrt_output = src_op.outputs[0] + "/rsqrt";
+ const string rsqrt_tiled_output = src_op.outputs[0] + "/rsqrt_tiled";
+
+ auto* sum_reduction_indices_op = tensorflow_graph->add_node();
+ sum_reduction_indices_op->set_op("Const");
+ sum_reduction_indices_op->set_name(sum_reduction_indices);
+ (*sum_reduction_indices_op->mutable_attr())["dtype"].set_type(DT_INT32);
+ auto* sum_reduction_indices_tensor =
+ (*sum_reduction_indices_op->mutable_attr())["value"].mutable_tensor();
+ sum_reduction_indices_tensor->set_dtype(DT_INT32);
+ auto* sum_reduction_indices_shape =
+ sum_reduction_indices_tensor->mutable_tensor_shape();
+ auto* sum_reduction_indices_dim = sum_reduction_indices_shape->add_dim();
+ sum_reduction_indices_dim->set_size(2);
+ sum_reduction_indices_tensor->add_int_val(0);
+ sum_reduction_indices_tensor->add_int_val(1);
+
+ auto* square_op = tensorflow_graph->add_node();
+ square_op->set_op("Square");
+ square_op->set_name(square_output);
+ *square_op->add_input() = src_op.inputs[0];
+ (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ auto* sum_op = tensorflow_graph->add_node();
+ sum_op->set_op("Sum");
+ sum_op->set_name(sum_output);
+ *sum_op->add_input() = square_output;
+ *sum_op->add_input() = sum_reduction_indices;
+ (*sum_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ auto* rsqrt_op = tensorflow_graph->add_node();
+ rsqrt_op->set_op("Rsqrt");
+ rsqrt_op->set_name(rsqrt_output);
+ *rsqrt_op->add_input() = sum_output;
+ (*rsqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ auto* mul_op = tensorflow_graph->add_node();
+ mul_op->set_op("Mul");
+ mul_op->set_name(src_op.outputs[0]);
+ *mul_op->add_input() = src_op.inputs[0];
+ *mul_op->add_input() = rsqrt_output;
+ (*mul_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertLocalResponseNormalizationOperator(
+ const LocalResponseNormalizationOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* lrn_op = tensorflow_graph->add_node();
+ lrn_op->set_op("LRN");
+ lrn_op->set_name(src_op.outputs[0]);
+ *lrn_op->add_input() = src_op.inputs[0];
+ (*lrn_op->mutable_attr())["depth_radius"].set_i(src_op.range);
+ (*lrn_op->mutable_attr())["bias"].set_f(src_op.bias);
+ (*lrn_op->mutable_attr())["alpha"].set_f(src_op.alpha);
+ (*lrn_op->mutable_attr())["beta"].set_f(src_op.beta);
+}
+
+void ConvertFakeQuantOperator(const FakeQuantOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* fakequant_op = tensorflow_graph->add_node();
+ fakequant_op->set_op("FakeQuantWithMinMaxArgs");
+ fakequant_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 1);
+ *fakequant_op->add_input() = src_op.inputs[0];
+ CHECK(src_op.minmax);
+ (*fakequant_op->mutable_attr())["min"].set_f(src_op.minmax->min);
+ (*fakequant_op->mutable_attr())["max"].set_f(src_op.minmax->max);
+}
+
+void ConvertMaxPoolOperator(const MaxPoolOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* maxpool_op = tensorflow_graph->add_node();
+ maxpool_op->set_op("MaxPool");
+ maxpool_op->set_name(src_op.outputs[0]);
+ *maxpool_op->add_input() = src_op.inputs[0];
+ auto& strides = (*maxpool_op->mutable_attr())["strides"];
+ strides.mutable_list()->add_i(1);
+ strides.mutable_list()->add_i(src_op.stride_height);
+ strides.mutable_list()->add_i(src_op.stride_width);
+ strides.mutable_list()->add_i(1);
+ string padding;
+ if (src_op.padding.type == PaddingType::kSame) {
+ padding = "SAME";
+ } else if (src_op.padding.type == PaddingType::kValid) {
+ padding = "VALID";
+ } else {
+ LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
+ }
+ (*maxpool_op->mutable_attr())["padding"].set_s(padding);
+ (*maxpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ auto& ksize = (*maxpool_op->mutable_attr())["ksize"];
+ ksize.mutable_list()->add_i(1);
+ ksize.mutable_list()->add_i(src_op.kheight);
+ ksize.mutable_list()->add_i(src_op.kwidth);
+ ksize.mutable_list()->add_i(1);
+}
+
+void ConvertAveragePoolOperator(const AveragePoolOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* avgpool_op = tensorflow_graph->add_node();
+ avgpool_op->set_op("AvgPool");
+ avgpool_op->set_name(src_op.outputs[0]);
+ *avgpool_op->add_input() = src_op.inputs[0];
+ auto& strides = (*avgpool_op->mutable_attr())["strides"];
+ strides.mutable_list()->add_i(1);
+ strides.mutable_list()->add_i(src_op.stride_height);
+ strides.mutable_list()->add_i(src_op.stride_width);
+ strides.mutable_list()->add_i(1);
+ string padding;
+ if (src_op.padding.type == PaddingType::kSame) {
+ padding = "SAME";
+ } else if (src_op.padding.type == PaddingType::kValid) {
+ padding = "VALID";
+ } else {
+ LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
+ }
+ (*avgpool_op->mutable_attr())["padding"].set_s(padding);
+ (*avgpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ auto& ksize = (*avgpool_op->mutable_attr())["ksize"];
+ ksize.mutable_list()->add_i(1);
+ ksize.mutable_list()->add_i(src_op.kheight);
+ ksize.mutable_list()->add_i(src_op.kwidth);
+ ksize.mutable_list()->add_i(1);
+}
+
+void ConvertConcatenationOperator(const Model& model,
+ const ConcatenationOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* dc_op = tensorflow_graph->add_node();
+ dc_op->set_op("ConcatV2");
+ dc_op->set_name(src_op.outputs[0]);
+ const string dummy_concat_dim = src_op.outputs[0] + "/concat_dim";
+ CreateDummyConcatDimTensorConst(dummy_concat_dim, src_op.concat_dim,
+ tensorflow_graph);
+ for (const auto& input : src_op.inputs) {
+ *dc_op->add_input() = input;
+ }
+ *dc_op->add_input() = dummy_concat_dim;
+ (*dc_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ (*dc_op->mutable_attr())["Tidx"].set_type(DT_INT32);
+ (*dc_op->mutable_attr())["N"].set_i(src_op.inputs.size());
+}
+
+void ConvertTensorFlowReshapeOperator(const Model& model,
+ const TensorFlowReshapeOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* reshape_op = tensorflow_graph->add_node();
+ reshape_op->set_op("Reshape");
+ reshape_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *reshape_op->add_input() = src_op.inputs[0];
+ *reshape_op->add_input() = src_op.inputs[1];
+ (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ const auto& shape_array = model.GetArray(src_op.inputs[1]);
+ CHECK(shape_array.data_type == ArrayDataType::kInt32);
+ CHECK(shape_array.buffer != nullptr);
+ const auto& shape_data = shape_array.GetBuffer<ArrayDataType::kInt32>().data;
+ CreateReshapeShapeTensorConst(src_op.inputs[1], shape_data, tensorflow_graph);
+}
+
+void ConvertL2PoolOperator(const L2PoolOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ const string square_output = src_op.outputs[0] + "/square";
+ const string avgpool_output = src_op.outputs[0] + "/avgpool";
+
+ auto* square_op = tensorflow_graph->add_node();
+ square_op->set_op("Square");
+ square_op->set_name(square_output);
+ *square_op->add_input() = src_op.inputs[0];
+ (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ string padding;
+ if (src_op.padding.type == PaddingType::kSame) {
+ padding = "SAME";
+ } else if (src_op.padding.type == PaddingType::kValid) {
+ padding = "VALID";
+ } else {
+ LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
+ }
+
+ auto* avgpool_op = tensorflow_graph->add_node();
+ avgpool_op->set_op("AvgPool");
+ avgpool_op->set_name(avgpool_output);
+ *avgpool_op->add_input() = square_output;
+ auto& strides = (*avgpool_op->mutable_attr())["strides"];
+ strides.mutable_list()->add_i(1);
+ strides.mutable_list()->add_i(src_op.stride_height);
+ strides.mutable_list()->add_i(src_op.stride_width);
+ strides.mutable_list()->add_i(1);
+
+ (*avgpool_op->mutable_attr())["padding"].set_s(padding);
+ (*avgpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ auto& ksize = (*avgpool_op->mutable_attr())["ksize"];
+ ksize.mutable_list()->add_i(1);
+ ksize.mutable_list()->add_i(src_op.kheight);
+ ksize.mutable_list()->add_i(src_op.kwidth);
+ ksize.mutable_list()->add_i(1);
+
+ auto* sqrt_op = tensorflow_graph->add_node();
+ sqrt_op->set_op("Sqrt");
+ sqrt_op->set_name(src_op.outputs[0]);
+ *sqrt_op->add_input() = avgpool_output;
+ (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertSquareOperator(const TensorFlowSquareOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* square_op = tensorflow_graph->add_node();
+ square_op->set_op("Square");
+ square_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 1);
+ *square_op->add_input() = src_op.inputs[0];
+ (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertSqrtOperator(const TensorFlowSqrtOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* sqrt_op = tensorflow_graph->add_node();
+ sqrt_op->set_op("Sqrt");
+ sqrt_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 1);
+ *sqrt_op->add_input() = src_op.inputs[0];
+ (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertSplitOperator(const Model& model,
+ const TensorFlowSplitOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* split_op = tensorflow_graph->add_node();
+ split_op->set_op("Split");
+ split_op->set_name(src_op.outputs[0]);
+ for (const auto& input : src_op.inputs) {
+ *split_op->add_input() = input;
+ }
+ (*split_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ (*split_op->mutable_attr())["num_split"].set_i(src_op.num_split);
+ const auto& split_dim_array = model.GetArray(src_op.inputs[0]);
+ CHECK(split_dim_array.buffer);
+ CHECK(split_dim_array.data_type == ArrayDataType::kInt32);
+ const auto& split_dim_data =
+ split_dim_array.GetBuffer<ArrayDataType::kInt32>().data;
+ CHECK_EQ(split_dim_data.size(), 1);
+ const int split_dim = split_dim_data[0];
+ CreateDummyConcatDimTensorConst(src_op.inputs[0], split_dim,
+ tensorflow_graph);
+}
+
+tensorflow::DataType GetTensorFlowDataType(const Model& model,
+ const string& array_name) {
+ auto& dtype = model.GetArray(array_name).data_type;
+ CHECK(dtype == ArrayDataType::kFloat || dtype == ArrayDataType::kInt32 ||
+ dtype == ArrayDataType::kUint8);
+ if (dtype == ArrayDataType::kFloat) {
+ return tensorflow::DT_FLOAT;
+ } else if (dtype == ArrayDataType::kInt32) {
+ return tensorflow::DT_INT32;
+ } else if (dtype == ArrayDataType::kUint8) {
+ return tensorflow::DT_UINT8;
+ } else {
+ LOG(FATAL) << "Wrong data type";
+ }
+}
+
+void ConvertCastOperator(const Model& model, const CastOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* cast_op = tensorflow_graph->add_node();
+ cast_op->set_op("Cast");
+ cast_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 1);
+ *cast_op->add_input() = src_op.inputs[0];
+
+ (*cast_op->mutable_attr())["DstT"].set_type(
+ GetTensorFlowDataType(model, src_op.outputs[0]));
+ (*cast_op->mutable_attr())["SrcT"].set_type(
+ GetTensorFlowDataType(model, src_op.inputs[0]));
+}
+
+void ConvertFloorOperator(const Model& model, const FloorOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* floor_op = tensorflow_graph->add_node();
+ floor_op->set_op("Floor");
+ floor_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 1);
+ *floor_op->add_input() = src_op.inputs[0];
+ (*floor_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertGatherOperator(const Model& model, const GatherOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* gather_op = tensorflow_graph->add_node();
+ gather_op->set_op("Gather");
+ gather_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *gather_op->add_input() = src_op.inputs[0];
+ *gather_op->add_input() = src_op.inputs[1];
+
+ (*gather_op->mutable_attr())["Tindices"].set_type(DT_INT32);
+ const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*gather_op->mutable_attr())["Tparams"].set_type(params_type);
+}
+
+void ConvertResizeBilinearOperator(const Model& model,
+ const ResizeBilinearOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* resize_op = tensorflow_graph->add_node();
+ resize_op->set_op("ResizeBilinear");
+ resize_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *resize_op->add_input() = src_op.inputs[0];
+ *resize_op->add_input() = src_op.inputs[1];
+ (*resize_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+namespace {
+// TODO(aselle): Remove when available in absl
+absl::string_view FindLongestCommonPrefix(absl::string_view a,
+ absl::string_view b) {
+ if (a.empty() || b.empty()) return absl::string_view();
+
+ const char* pa = a.data();
+ const char* pb = b.data();
+ string::difference_type count = 0;
+ const string::difference_type limit = std::min(a.size(), b.size());
+ while (count < limit && *pa == *pb) {
+ ++pa;
+ ++pb;
+ ++count;
+ }
+
+ return absl::string_view(a.data(), count);
+}
+} // namespace
+
+void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ // Find the base name
+ const string base(
+ FindLongestCommonPrefix(src_op.outputs[LstmCellOperator::STATE_OUTPUT],
+ src_op.outputs[LstmCellOperator::ACTIV_OUTPUT]));
+
+ // Concatenate inputs
+ const string concat_output = base + "basic_lstm_cell/concat";
+ // Op names have been chosen to match the tf.slim LSTM naming
+ // as closely as possible.
+ const int concat_dim =
+ model.arrays.at(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT])
+ ->shape()
+ .dimensions_count() -
+ 1;
+ // Note that DATA_INPUT may have extra size 1 dimensions, but TF concat
+ // works the same since the tensor has the same underlying data layout.
+ const string concat_dim_output = concat_output + "/concat_dim";
+ CreateDummyConcatDimTensorConst(concat_dim_output, concat_dim,
+ tensorflow_graph);
+ auto* concat_op = tensorflow_graph->add_node();
+ concat_op->set_op("ConcatV2");
+ concat_op->set_name(concat_output);
+ *concat_op->add_input() = src_op.inputs[LstmCellOperator::DATA_INPUT];
+ *concat_op->add_input() = src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT];
+ *concat_op->add_input() = concat_dim_output;
+ (*concat_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ (*concat_op->mutable_attr())["Tidx"].set_type(DT_INT32);
+ (*concat_op->mutable_attr())["N"].set_i(2); // Number of inputs
+
+ // Write weights
+ const string weights_output = base + "weights";
+ CHECK(model.arrays.count(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]));
+ const auto& weights_array =
+ *model.arrays.at(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]);
+ // Convert 4D FullyConnected weights into 2D matrix
+ const auto& weights_shape = weights_array.shape();
+ CHECK_EQ(weights_shape.dimensions_count(), 2);
+ CHECK(weights_array.buffer);
+ CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
+ const float* weights_data =
+ weights_array.GetBuffer<ArrayDataType::kFloat>().data.data();
+ ConvertFloatTensorConst(weights_output, weights_shape, weights_data,
+ AxesOrder::kCR, AxesOrder::kRC, tensorflow_graph);
+
+ // Fully connected matrix multiply
+ const string matmul_output = base + "MatMul";
+ auto* matmul_op = tensorflow_graph->add_node();
+ matmul_op->set_op("MatMul");
+ matmul_op->set_name(matmul_output);
+ *matmul_op->add_input() = concat_output;
+ *matmul_op->add_input() = weights_output;
+ (*matmul_op->mutable_attr())["transpose_a"].set_b(false);
+ (*matmul_op->mutable_attr())["transpose_b"].set_b(false);
+ (*matmul_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ // Write biases
+ const string biases_output = base + "biases";
+ CHECK(model.arrays.count(src_op.inputs[LstmCellOperator::BIASES_INPUT]));
+ const auto& bias_array =
+ *model.arrays.at(src_op.inputs[LstmCellOperator::BIASES_INPUT]);
+ // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
+ Shape bias_shape_1d = bias_array.shape();
+ UnextendShape(&bias_shape_1d, 1);
+ CHECK(bias_array.buffer);
+ CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
+ const float* bias_data =
+ bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
+ ConvertFloatTensorConst(biases_output, bias_shape_1d, bias_data,
+ AxesOrder::kOneAxis, AxesOrder::kOneAxis,
+ tensorflow_graph,
+ LegacyScalarPolicy::kDoCreateLegacyScalars);
+
+ // Add biases
+ string biasadd_output = base + "BiasAdd";
+ auto* biasadd_op = tensorflow_graph->add_node();
+ biasadd_op->set_op("BiasAdd");
+ biasadd_op->set_name(biasadd_output);
+ biasadd_op->add_input(matmul_output);
+ biasadd_op->add_input(biases_output);
+ (*biasadd_op->mutable_attr())["data_format"].set_s("NHWC");
+ (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ // Split
+ string split_dim_output = base + "split/split_dim";
+ // The dimension is the same as the concatenation dimension
+ CreateDummyConcatDimTensorConst(split_dim_output, concat_dim,
+ tensorflow_graph);
+ string split_output = base + "split";
+ auto* split_op = tensorflow_graph->add_node();
+ split_op->set_op("Split");
+ split_op->set_name(split_output);
+ *split_op->add_input() = split_dim_output;
+ *split_op->add_input() = biasadd_output;
+ (*split_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ (*split_op->mutable_attr())["num_split"].set_i(4); // Split into four outputs
+
+ // Activation functions and memory computations
+ const string tanh_0_output = base + "Tanh";
+ auto* tanh_0_op = tensorflow_graph->add_node();
+ tanh_0_op->set_op("Tanh");
+ tanh_0_op->set_name(tanh_0_output);
+ *tanh_0_op->add_input() = split_output + ":1";
+ (*tanh_0_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ const string sigmoid_1_output = base + "Sigmoid_1";
+ auto* logistic_1_op = tensorflow_graph->add_node();
+ logistic_1_op->set_op("Sigmoid");
+ logistic_1_op->set_name(sigmoid_1_output);
+ *logistic_1_op->add_input() = split_output;
+ (*logistic_1_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ const string mul_1_output = base + "mul_1";
+ auto* mul_1_op = tensorflow_graph->add_node();
+ mul_1_op->set_op("Mul");
+ mul_1_op->set_name(mul_1_output);
+ *mul_1_op->add_input() = sigmoid_1_output;
+ *mul_1_op->add_input() = tanh_0_output;
+ (*mul_1_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ const string sigmoid_0_output = base + "Sigmoid";
+ auto* logistic_2_op = tensorflow_graph->add_node();
+ logistic_2_op->set_op("Sigmoid");
+ logistic_2_op->set_name(sigmoid_0_output);
+ *logistic_2_op->add_input() = split_output + ":2";
+ (*logistic_2_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ const string sigmoid_2_output = base + "Sigmoid_2";
+ auto* logistic_3_op = tensorflow_graph->add_node();
+ logistic_3_op->set_op("Sigmoid");
+ logistic_3_op->set_name(sigmoid_2_output);
+ *logistic_3_op->add_input() = split_output + ":3";
+ (*logistic_3_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ const string mul_0_output = base + "mul";
+ auto* mul_0_op = tensorflow_graph->add_node();
+ mul_0_op->set_op("Mul");
+ mul_0_op->set_name(mul_0_output);
+ *mul_0_op->add_input() = src_op.inputs[LstmCellOperator::PREV_STATE_INPUT];
+ *mul_0_op->add_input() = sigmoid_0_output;
+ (*mul_0_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ const string add_1_output = src_op.outputs[LstmCellOperator::STATE_OUTPUT];
+ auto* add_1_op = tensorflow_graph->add_node();
+ add_1_op->set_op("Add");
+ add_1_op->set_name(add_1_output);
+ *add_1_op->add_input() = mul_0_output;
+ *add_1_op->add_input() = mul_1_output;
+ (*add_1_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ const string tanh_1_output = base + "Tanh_1";
+ auto* tanh_1_op = tensorflow_graph->add_node();
+ tanh_1_op->set_op("Tanh");
+ tanh_1_op->set_name(tanh_1_output);
+ *tanh_1_op->add_input() = add_1_output;
+ (*tanh_1_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ const string mul_2_output = src_op.outputs[LstmCellOperator::ACTIV_OUTPUT];
+ auto* mul_2_op = tensorflow_graph->add_node();
+ mul_2_op->set_op("Mul");
+ mul_2_op->set_name(mul_2_output);
+ *mul_2_op->add_input() = tanh_1_output;
+ *mul_2_op->add_input() = sigmoid_2_output;
+ (*mul_2_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertSpaceToBatchNDOperator(const Model& model,
+ const SpaceToBatchNDOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* new_op = tensorflow_graph->add_node();
+ new_op->set_op("SpaceToBatchND");
+ new_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 3);
+ *new_op->add_input() = src_op.inputs[0];
+ *new_op->add_input() = src_op.inputs[1];
+ *new_op->add_input() = src_op.inputs[2];
+ const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*new_op->mutable_attr())["T"].set_type(params_type);
+ (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32);
+ (*new_op->mutable_attr())["Tpaddings"].set_type(DT_INT32);
+}
+
+void ConvertBatchToSpaceNDOperator(const Model& model,
+ const BatchToSpaceNDOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* new_op = tensorflow_graph->add_node();
+ new_op->set_op("BatchToSpaceND");
+ new_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 3);
+ *new_op->add_input() = src_op.inputs[0];
+ *new_op->add_input() = src_op.inputs[1];
+ *new_op->add_input() = src_op.inputs[2];
+ const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*new_op->mutable_attr())["T"].set_type(params_type);
+ (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32);
+ (*new_op->mutable_attr())["Tcrops"].set_type(DT_INT32);
+}
+
+void ConvertPadOperator(const Model& model, const PadOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* new_op = tensorflow_graph->add_node();
+ new_op->set_op("Pad");
+ new_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *new_op->add_input() = src_op.inputs[0];
+ *new_op->add_input() = src_op.inputs[1];
+
+ const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*new_op->mutable_attr())["T"].set_type(params_type);
+
+ // Create the params tensor.
+ auto* params_op = tensorflow_graph->add_node();
+ params_op->set_op("Const");
+ params_op->set_name(src_op.inputs[1]);
+ (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
+ auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
+ tensor->set_dtype(DT_INT32);
+
+ CHECK_EQ(src_op.left_padding.size(), src_op.right_padding.size());
+ for (int i = 0; i < src_op.left_padding.size(); ++i) {
+ tensor->add_int_val(src_op.left_padding[i]);
+ tensor->add_int_val(src_op.right_padding[i]);
+ }
+ auto* shape = tensor->mutable_tensor_shape();
+ shape->add_dim()->set_size(src_op.left_padding.size());
+ shape->add_dim()->set_size(2);
+}
+
+void CreateSliceInput(const string& input_name, const std::vector<int>& values,
+ GraphDef* tensorflow_graph) {
+ auto* params_op = tensorflow_graph->add_node();
+ params_op->set_op("Const");
+ params_op->set_name(input_name);
+ (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
+ auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
+ tensor->set_dtype(DT_INT32);
+
+ for (int i = 0; i < values.size(); ++i) {
+ tensor->add_int_val(values[i]);
+ }
+ auto* shape = tensor->mutable_tensor_shape();
+ shape->add_dim()->set_size(values.size());
+}
+
+void ConvertStridedSliceOperator(const Model& model,
+ const StridedSliceOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* new_op = tensorflow_graph->add_node();
+ new_op->set_op("StridedSlice");
+ new_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 4);
+ *new_op->add_input() = src_op.inputs[0];
+ *new_op->add_input() = src_op.inputs[1];
+ *new_op->add_input() = src_op.inputs[2];
+ *new_op->add_input() = src_op.inputs[3];
+
+ const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*new_op->mutable_attr())["T"].set_type(params_type);
+
+ (*new_op->mutable_attr())["Index"].set_type(DT_INT32);
+ (*new_op->mutable_attr())["begin_mask"].set_i(src_op.begin_mask);
+ (*new_op->mutable_attr())["ellipsis_mask"].set_i(src_op.ellipsis_mask);
+ (*new_op->mutable_attr())["end_mask"].set_i(src_op.end_mask);
+ (*new_op->mutable_attr())["new_axis_mask"].set_i(src_op.new_axis_mask);
+ (*new_op->mutable_attr())["shrink_axis_mask"].set_i(src_op.shrink_axis_mask);
+
+ // Create tensors for start/stop indices and strides.
+ CreateSliceInput(src_op.inputs[1], src_op.start_indices, tensorflow_graph);
+ CreateSliceInput(src_op.inputs[2], src_op.stop_indices, tensorflow_graph);
+ CreateSliceInput(src_op.inputs[3], src_op.strides, tensorflow_graph);
+}
+
+void ConvertSliceOperator(const Model& model, const SliceOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* new_op = tensorflow_graph->add_node();
+ new_op->set_op("Slice");
+ new_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 3);
+ *new_op->add_input() = src_op.inputs[0];
+ *new_op->add_input() = src_op.inputs[1];
+ *new_op->add_input() = src_op.inputs[2];
+
+ const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*new_op->mutable_attr())["T"].set_type(params_type);
+ (*new_op->mutable_attr())["Index"].set_type(DT_INT32);
+
+ // Create tensors for begin and size inputs.
+ CreateSliceInput(src_op.inputs[1], src_op.begin, tensorflow_graph);
+ CreateSliceInput(src_op.inputs[2], src_op.size, tensorflow_graph);
+}
+
+void ConvertMeanOperator(const Model& model, const MeanOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* new_op = tensorflow_graph->add_node();
+ new_op->set_op("Mean");
+ new_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *new_op->add_input() = src_op.inputs[0];
+ *new_op->add_input() = src_op.inputs[1];
+
+ const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*new_op->mutable_attr())["T"].set_type(params_type);
+
+ // Create the params tensor.
+ auto* params_op = tensorflow_graph->add_node();
+ params_op->set_op("Const");
+ params_op->set_name(src_op.inputs[1]);
+ (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
+ auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
+ tensor->set_dtype(DT_INT32);
+
+ for (int i = 0; i < src_op.reduction_indices.size(); ++i) {
+ tensor->add_int_val(src_op.reduction_indices[i]);
+ }
+ auto* shape = tensor->mutable_tensor_shape();
+ shape->add_dim()->set_size(src_op.reduction_indices.size());
+}
+
+void ConvertSqueezeOperator(const Model& model, const SqueezeOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* new_op = tensorflow_graph->add_node();
+ new_op->set_op("Squeeze");
+ new_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 1);
+ *new_op->add_input() = src_op.inputs[0];
+
+ const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*new_op->mutable_attr())["T"].set_type(params_type);
+
+ auto& squeeze_dims = (*new_op->mutable_attr())["squeeze_dims"];
+ for (int i : src_op.squeeze_dims) {
+ squeeze_dims.mutable_list()->add_i(i);
+ }
+}
+
+void ConvertSubOperator(const Model& model, const SubOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* sub_op = tensorflow_graph->add_node();
+ sub_op->set_op("Sub");
+ sub_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *sub_op->add_input() = src_op.inputs[0];
+ *sub_op->add_input() = src_op.inputs[1];
+ const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*sub_op->mutable_attr())["T"].set_type(data_type);
+}
+
+void ConvertTensorFlowMinimumOperator(const Model& model,
+ const TensorFlowMinimumOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* sub_op = tensorflow_graph->add_node();
+ sub_op->set_op("Minimum");
+ sub_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *sub_op->add_input() = src_op.inputs[0];
+ *sub_op->add_input() = src_op.inputs[1];
+ const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*sub_op->mutable_attr())["T"].set_type(data_type);
+}
+
+void ConvertTensorFlowMaximumOperator(const Model& model,
+ const TensorFlowMaximumOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* sub_op = tensorflow_graph->add_node();
+ sub_op->set_op("Maximum");
+ sub_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *sub_op->add_input() = src_op.inputs[0];
+ *sub_op->add_input() = src_op.inputs[1];
+ const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*sub_op->mutable_attr())["T"].set_type(data_type);
+}
+
+void ConvertOperator(const Model& model, const Operator& src_op,
+ GraphDef* tensorflow_graph) {
+ if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
+ LOG(FATAL)
+ << "Unsupported: the input model has a fused activation function";
+ }
+
+ if (src_op.type == OperatorType::kConv) {
+ ConvertConvOperator(model, static_cast<const ConvOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kDepthwiseConv) {
+ ConvertDepthwiseConvOperator(
+ model, static_cast<const DepthwiseConvOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kDepthToSpace) {
+ ConvertDepthToSpaceOperator(
+ model, static_cast<const DepthToSpaceOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kSpaceToDepth) {
+ ConvertSpaceToDepthOperator(
+ model, static_cast<const SpaceToDepthOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kFullyConnected) {
+ ConvertFullyConnectedOperator(
+ model, static_cast<const FullyConnectedOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kAdd) {
+ ConvertAddOperator(model, static_cast<const AddOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kMul) {
+ ConvertMulOperator(model, static_cast<const MulOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kRelu) {
+ ConvertReluOperator(static_cast<const ReluOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kRelu1) {
+ ConvertRelu1Operator(static_cast<const Relu1Operator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kRelu6) {
+ ConvertRelu6Operator(static_cast<const Relu6Operator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kLogistic) {
+ ConvertLogisticOperator(static_cast<const LogisticOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTanh) {
+ ConvertTanhOperator(static_cast<const TanhOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kL2Normalization) {
+ ConvertL2NormalizationOperator(
+ static_cast<const L2NormalizationOperator&>(src_op), tensorflow_graph);
+ } else if (src_op.type == OperatorType::kSoftmax) {
+ ConvertSoftmaxOperator(model, static_cast<const SoftmaxOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kLocalResponseNormalization) {
+ ConvertLocalResponseNormalizationOperator(
+ static_cast<const LocalResponseNormalizationOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kLstmCell) {
+ ConvertLstmCellOperator(model, static_cast<const LstmCellOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kMaxPool) {
+ ConvertMaxPoolOperator(static_cast<const MaxPoolOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kAveragePool) {
+ ConvertAveragePoolOperator(static_cast<const AveragePoolOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kConcatenation) {
+ ConvertConcatenationOperator(
+ model, static_cast<const ConcatenationOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowReshape) {
+ ConvertTensorFlowReshapeOperator(
+ model, static_cast<const TensorFlowReshapeOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kL2Pool) {
+ ConvertL2PoolOperator(static_cast<const L2PoolOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowSquare) {
+ ConvertSquareOperator(static_cast<const TensorFlowSquareOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowSqrt) {
+ ConvertSqrtOperator(static_cast<const TensorFlowSqrtOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowSplit) {
+ ConvertSplitOperator(model,
+ static_cast<const TensorFlowSplitOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kFakeQuant) {
+ ConvertFakeQuantOperator(static_cast<const FakeQuantOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kCast) {
+ ConvertCastOperator(model, static_cast<const CastOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kFloor) {
+ ConvertFloorOperator(model, static_cast<const FloorOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kGather) {
+ ConvertGatherOperator(model, static_cast<const GatherOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kResizeBilinear) {
+ ConvertResizeBilinearOperator(
+ model, static_cast<const ResizeBilinearOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kSpaceToBatchND) {
+ ConvertSpaceToBatchNDOperator(
+ model, static_cast<const SpaceToBatchNDOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kBatchToSpaceND) {
+ ConvertBatchToSpaceNDOperator(
+ model, static_cast<const BatchToSpaceNDOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kPad) {
+ ConvertPadOperator(model, static_cast<const PadOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kStridedSlice) {
+ ConvertStridedSliceOperator(
+ model, static_cast<const StridedSliceOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kMean) {
+ ConvertMeanOperator(model, static_cast<const MeanOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kSub) {
+ ConvertSubOperator(model, static_cast<const SubOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowMinimum) {
+ ConvertTensorFlowMinimumOperator(
+ model, static_cast<const TensorFlowMinimumOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowMaximum) {
+ ConvertTensorFlowMaximumOperator(
+ model, static_cast<const TensorFlowMaximumOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kSqueeze) {
+ ConvertSqueezeOperator(model, static_cast<const SqueezeOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kSlice) {
+ ConvertSliceOperator(model, static_cast<const SliceOperator&>(src_op),
+ tensorflow_graph);
+ } else {
+ LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
+ }
+}
+
+void AddPlaceholder(const string& name, GraphDef* tensorflow_graph) {
+ auto* placeholder = tensorflow_graph->add_node();
+ placeholder->set_op("Placeholder");
+ (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
+ placeholder->set_name(name);
+}
+
+void AddPlaceholderForRNNState(const Model& model, const string& name, int size,
+ GraphDef* tensorflow_graph) {
+ auto* placeholder = tensorflow_graph->add_node();
+ placeholder->set_op("Placeholder");
+ placeholder->set_name(name);
+ (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
+
+ auto* shape = (*placeholder->mutable_attr())["shape"].mutable_shape();
+ const auto& state_array = *model.arrays.at(name);
+ if (state_array.has_shape()) {
+ const auto& state_shape = state_array.shape();
+ const int kDims = state_shape.dimensions_count();
+ for (int i = 0; i < kDims; ++i) {
+ shape->add_dim()->set_size(state_shape.dims(i));
+ }
+ } else {
+ shape->add_dim()->set_size(1);
+ shape->add_dim()->set_size(size);
+ }
+}
+
+void ExportTensorFlowGraphDefImplementation(const Model& model,
+ GraphDef* tensorflow_graph) {
+ for (const auto& input_array : model.flags.input_arrays()) {
+ AddPlaceholder(input_array.name(), tensorflow_graph);
+ }
+ for (const auto& rnn_state : model.flags.rnn_states()) {
+ AddPlaceholderForRNNState(model, rnn_state.state_array(), rnn_state.size(),
+ tensorflow_graph);
+ }
+ for (const auto& op : model.operators) {
+ ConvertOperator(model, *op, tensorflow_graph);
+ }
+ // Generically export arrays that haven't been exported already
+ // by the above operators export. It's important that this comes
+ // after, as some operators need to export arrays that they reference
+ // in a specific way, rather than in the generic way done below.
+ for (const auto& array_pair : model.arrays) {
+ const string& array_name = array_pair.first;
+ const auto& array = *array_pair.second;
+ if (array.buffer) {
+ switch (array.data_type) {
+ case ArrayDataType::kFloat:
+ ConvertFloatTensorConst(model, array_name, tensorflow_graph);
+ break;
+ case ArrayDataType::kInt32:
+ ConvertIntTensorConst(model, array_name, tensorflow_graph);
+ break;
+ default:
+ break;
+ }
+ }
+ }
+}
+} // namespace
+
+void ExportTensorFlowGraphDef(const Model& model,
+ string* output_file_contents) {
+ CHECK(output_file_contents->empty());
+ GraphDef tensorflow_graph;
+ ExportTensorFlowGraphDefImplementation(model, &tensorflow_graph);
+ LogDumpGraphDef(kLogLevelModelChanged, "AT EXPORT", tensorflow_graph);
+ CHECK(tensorflow_graph.SerializeToString(output_file_contents));
+}
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.h b/tensorflow/contrib/lite/toco/export_tensorflow.h
new file mode 100644
index 0000000000..eca9774576
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.h
@@ -0,0 +1,27 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_
+
+#include <string>
+#include "tensorflow/contrib/lite/toco/model.h"
+
+namespace toco {
+
+void ExportTensorFlowGraphDef(const Model& model, string* output_file_contents);
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_
diff --git a/tensorflow/contrib/lite/toco/format_port.h b/tensorflow/contrib/lite/toco/format_port.h
new file mode 100644
index 0000000000..3bc3295d04
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/format_port.h
@@ -0,0 +1,77 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// This file is used to provide equivalents of internal util::format::FormatF
+// and util::format::AppendF. Unfortunately, type safety is not as good as a
+// a full C++ example.
+// TODO(aselle): When absl adds support for StrFormat, use that instead.
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_
+
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+
+namespace toco {
+namespace port {
+
+/// Identity (default case)
+template <class T>
+T IdentityOrConvertStringToRaw(T foo) {
+ return foo;
+}
+
+// Overloaded case where we return std::string.
+inline const char* IdentityOrConvertStringToRaw(const std::string& foo) {
+ return foo.c_str();
+}
+
+#if defined(PLATFORM_GOOGLE)
+// Overloaded case where we return string.
+inline const char* IdentityOrConvertStringToRaw(const string& foo) {
+ return foo.c_str();
+}
+#endif // PLATFORM_GOOGLE
+// Delegate to TensorFlow Appendf function until absl has an equivalent.
+template <typename... Args>
+inline void AppendFHelper(string* destination, const char* fmt,
+ Args&&... args) {
+ tensorflow::strings::Appendf(destination, fmt, args...);
+}
+
+// Specialization for no argument format string (avoid security bug).
+inline void AppendFHelper(string* destination, const char* fmt) {
+ tensorflow::strings::Appendf(destination, "%s", fmt);
+}
+
+// Append formatted string (with format fmt and args args) to the string
+// pointed to by destination. fmt follows C printf semantics.
+// One departure is that %s can be driven by a std::string or string.
+template <typename... Args>
+inline void AppendF(string* destination, const char* fmt, Args&&... args) {
+ AppendFHelper(destination, fmt, IdentityOrConvertStringToRaw(args)...);
+}
+
+// Return formatted string (with format fmt and args args). fmt follows C printf
+// semantics. One departure is that %s can be driven by a std::string or string.
+template <typename... Args>
+inline string StringF(const char* fmt, Args&&... args) {
+ string result;
+ AppendFHelper(&result, fmt, IdentityOrConvertStringToRaw(args)...);
+ return result;
+}
+
+} // namespace port
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
new file mode 100644
index 0000000000..bf454c40c7
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
@@ -0,0 +1,98 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) {
+ auto conv_it = model->operators.begin() + op_index;
+ if (conv_it->get()->type != OperatorType::kConv) {
+ return false;
+ }
+ const auto* conv_op = static_cast<ConvOperator*>(conv_it->get());
+ if (conv_op->stride_width != conv_op->stride_height) {
+ return false;
+ }
+ auto& weights_array = model->GetArray(conv_op->inputs[1]);
+ if (!weights_array.buffer) {
+ // Yield until the weights are resolved as a constant array.
+ return false;
+ }
+ if (weights_array.data_type != ArrayDataType::kFloat) {
+ return false;
+ }
+ if (weights_array.shape().dims(3) != 1) {
+ // Not a pure convolution: Conv does accumulation across the depth
+ // dimension.
+ return false;
+ }
+ // At this point we know we have a pure conv. Rewrite it as DepthwiseConv.
+ AddMessageF(
+ "%s is purely convolutional (input/weights depth is 1), replacing it by "
+ "a DepthwiseConv.",
+ LogName(*conv_op));
+ auto* depthwiseconv_op = new DepthwiseConvOperator;
+ // Conv and DepthwiseConv take the same inputs
+ depthwiseconv_op->inputs = conv_op->inputs;
+ // Conv may have a 2nd output for im2col
+ depthwiseconv_op->outputs = {conv_op->outputs[0]};
+ if (conv_op->outputs.size() > 1) {
+ // delete the im2col array.
+ model->arrays.erase(conv_op->outputs[1]);
+ }
+ depthwiseconv_op->fused_activation_function =
+ conv_op->fused_activation_function;
+ // Let PropagateFixedSizes recompute fixed padding, just in case some day it
+ // may be different for Conv vs DepthwiseConv.
+ depthwiseconv_op->padding.type = conv_op->padding.type;
+ depthwiseconv_op->stride_height = conv_op->stride_height;
+ depthwiseconv_op->stride_width = conv_op->stride_width;
+ depthwiseconv_op->depth_multiplier = weights_array.shape().dims(0);
+ // Replace the operator in the graph.
+ const auto depthwiseconv_it =
+ model->operators.emplace(conv_it, depthwiseconv_op);
+ conv_it = depthwiseconv_it + 1;
+ CHECK_EQ(conv_it->get(), conv_op);
+ model->operators.erase(conv_it);
+ // Shuffle the weights.
+ const auto& weights_shape = weights_array.shape();
+ auto& weights_buffer =
+ weights_array.GetMutableBuffer<ArrayDataType::kFloat>();
+ const std::vector<float>& conv_weights_data = weights_buffer.data;
+ std::vector<float> depthwise_conv_weights_data(conv_weights_data.size());
+ const int depth = weights_shape.dims(0);
+ const int width = weights_shape.dims(1);
+ const int height = weights_shape.dims(2);
+ const int width_height = width * height;
+ for (int c = 0; c < depth; c++) {
+ for (int xy = 0; xy < width_height; xy++) {
+ depthwise_conv_weights_data[c + depth * xy] =
+ conv_weights_data[xy + width_height * c];
+ }
+ }
+ *weights_array.mutable_shape()->mutable_dims() = {1, width, height, depth};
+ weights_buffer.data = depthwise_conv_weights_data;
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
new file mode 100644
index 0000000000..1735b51e5b
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
@@ -0,0 +1,69 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) {
+ auto conv_it = model->operators.begin() + op_index;
+ if (conv_it->get()->type != OperatorType::kConv) {
+ return false;
+ }
+ auto* conv_op = static_cast<ConvOperator*>(conv_it->get());
+ if (conv_op->outputs.size() == 2) {
+ // We already have an im2col array
+ return false;
+ }
+ const auto& weights_array = *model->arrays[conv_op->inputs[1]];
+ if (!weights_array.has_shape()) {
+ // We need to yield until weights dims have been resolved, because
+ // from the weights dims we determine whether an im2col array is
+ // needed.
+ return false;
+ }
+ const auto& weights_shape = weights_array.shape();
+ const int kheight = weights_shape.dims(1);
+ const int kwidth = weights_shape.dims(2);
+ if (kwidth == 1 && kheight == 1 && conv_op->stride_width == 1 &&
+ conv_op->stride_height == 1) {
+ // 1x1 unstrided conv does not need an im2col array.
+ return false;
+ }
+
+ // Create the im2col array.
+ CHECK_EQ(conv_op->outputs.size(), 1);
+ const string& im2col_array_name =
+ AvailableArrayName(*model, conv_op->inputs[0] + "_im2col");
+ model->GetOrCreateArray(im2col_array_name);
+ conv_op->outputs.push_back(im2col_array_name);
+ AddMessageF(
+ "Created an im2col array for %s, with %dx%d kernel and stride_width=%d, "
+ "stride_height=%d",
+ LogName(*conv_op), kwidth, kheight, conv_op->stride_width,
+ conv_op->stride_height);
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
new file mode 100644
index 0000000000..b89e3f5310
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
@@ -0,0 +1,223 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+template <ArrayDataType A>
+void DequantizeBuffer(Array* array) {
+ const auto old_data = array->GetBuffer<A>().data;
+ array->buffer = nullptr;
+ array->data_type = ArrayDataType::kFloat;
+ auto& new_data = array->GetMutableBuffer<ArrayDataType::kFloat>().data;
+ new_data.resize(old_data.size());
+ const auto& qparams = array->GetQuantizationParams();
+ for (int i = 0; i < old_data.size(); i++) {
+ new_data[i] = qparams.scale * (old_data[i] - qparams.zero_point);
+ }
+}
+
+std::vector<std::unique_ptr<Operator>>::iterator FindFirstOpWithInput(
+ Model* model, const string& array_name) {
+ for (auto it = model->operators.begin(); it != model->operators.end(); ++it) {
+ for (const auto& input : it->get()->inputs) {
+ if (input == array_name) {
+ return it;
+ }
+ }
+ }
+ return model->operators.end();
+}
+
+void ClearArrayQuantizationParams(const string& array_name, Model* model) {
+ auto* array = model->arrays.at(array_name).get();
+ CHECK(array->quantization_params);
+ for (auto& input_array : *model->flags.mutable_input_arrays()) {
+ if (input_array.name() == array_name) {
+ auto& qparams = *array->quantization_params;
+ const double new_std_value = 1. / qparams.scale;
+ const double new_mean_value = qparams.zero_point;
+ if (input_array.has_std_value()) {
+ CHECK_LE(std::abs(new_std_value - input_array.std_value()), 0.001);
+ } else {
+ input_array.set_std_value(new_std_value);
+ }
+ if (input_array.has_mean_value()) {
+ CHECK_LE(std::abs(new_mean_value - input_array.mean_value()), 0.001);
+ } else {
+ input_array.set_mean_value(new_mean_value);
+ }
+ }
+ }
+ array->quantization_params = nullptr;
+}
+
+bool DequantizeArray(const string& array_name,
+ GraphTransformation* transformation, Model* model) {
+ auto* array = model->arrays.at(array_name).get();
+ if (!array->quantization_params) {
+ return false;
+ }
+ transformation->AddMessageF("Dequantizing array: %s", array_name);
+
+ // Dequantize any buffer
+ if (array->buffer) {
+ if (array->data_type == ArrayDataType::kUint8) {
+ DequantizeBuffer<ArrayDataType::kUint8>(array);
+ } else if (array->data_type == ArrayDataType::kInt32) {
+ DequantizeBuffer<ArrayDataType::kInt32>(array);
+ } else {
+ LOG(FATAL) << "Unhandled data type";
+ }
+ CHECK(array->data_type == ArrayDataType::kFloat);
+ CHECK(array->buffer->type == ArrayDataType::kFloat);
+
+ // Clear quantization params, officially makes this a non-quantized array.
+ ClearArrayQuantizationParams(array_name, model);
+ return true;
+ } else {
+ array->data_type = ArrayDataType::kFloat;
+ }
+
+ // Clear quantization params, officially makes this a non-quantized array.
+ ClearArrayQuantizationParams(array_name, model);
+
+ if (array->buffer) {
+ return true;
+ }
+
+ auto* op_outputting_array = GetOpWithOutput(*model, array_name);
+ if (op_outputting_array) {
+ if (op_outputting_array->type == OperatorType::kTensorFlowReshape) {
+ return true;
+ }
+ }
+
+ // If there was no minmax info, we can return now. Indeed,
+ // the below only serves to create a FakeQuant node, but some arrays are
+ // quantized without MinMax (see the CHECK above) and that corresponds to
+ // places where a FakeQuant node is actually not wanted, because the
+ // quantization params are meant to be inferred in another way (e.g. bias
+ // vector for a Conv op, see their special-casing in quantize.cc).
+ if (!array->minmax) {
+ return true;
+ }
+
+ // Determine whether to insert a FakeQuant before or after
+ // this array.
+ bool must_insert_fakequant_before = false;
+ bool must_insert_fakequant_after = false;
+ if (IsInputArray(*model, array_name)) {
+ must_insert_fakequant_after = true;
+ }
+ for (const string& output_array : model->flags.output_arrays()) {
+ if (array_name == output_array) {
+ must_insert_fakequant_before = true;
+ }
+ }
+ for (const auto& rnn_state : model->flags.rnn_states()) {
+ if (array_name == rnn_state.state_array()) {
+ must_insert_fakequant_after = true;
+ }
+ if (array_name == rnn_state.back_edge_source_array()) {
+ must_insert_fakequant_before = true;
+ }
+ }
+ CHECK(!(must_insert_fakequant_before && must_insert_fakequant_after));
+
+ // Create and insert the FakeQuant node
+ auto* fakequant_op = new FakeQuantOperator;
+ model->operators.emplace(FindFirstOpWithInput(model, array_name),
+ fakequant_op);
+ const string& new_array_name = AvailableArrayName(*model, array_name);
+ auto& new_array = model->GetOrCreateArray(new_array_name);
+ new_array.data_type = ArrayDataType::kFloat;
+ new_array.copy_shape(array->shape());
+ new_array.GetOrCreateMinMax() = array->GetMinMax();
+ fakequant_op->minmax.reset(new MinMax);
+ *fakequant_op->minmax = array->GetMinMax();
+ if (must_insert_fakequant_before) {
+ for (const auto& op : model->operators) {
+ for (string& output : op->outputs) {
+ if (output == array_name) {
+ output = new_array_name;
+ }
+ }
+ }
+ fakequant_op->inputs = {new_array_name};
+ fakequant_op->outputs = {array_name};
+ } else {
+ for (const auto& op : model->operators) {
+ for (string& input : op->inputs) {
+ if (input == array_name) {
+ input = new_array_name;
+ }
+ }
+ }
+ fakequant_op->inputs = {array_name};
+ fakequant_op->outputs = {new_array_name};
+ }
+ return true;
+}
+
+} // namespace
+
+bool Dequantize::Run(Model* model, std::size_t op_index) {
+ const auto op_it = model->operators.begin() + op_index;
+ auto* op = op_it->get();
+
+ if (op->type == OperatorType::kDequantize) {
+ auto& input_array = model->GetArray(op->inputs[0]);
+ if (input_array.data_type == ArrayDataType::kFloat) {
+ return false;
+ }
+ if (input_array.final_data_type != ArrayDataType::kFloat) {
+ return false;
+ }
+ input_array.data_type = ArrayDataType::kFloat;
+ input_array.quantization_params = nullptr;
+ auto& output_array = model->GetArray(op->outputs[0]);
+ output_array.data_type = ArrayDataType::kFloat;
+ output_array.quantization_params = nullptr;
+ return RemoveTrivialPassthroughOp(this, model, op_index);
+ }
+
+ std::vector<string> arrays;
+ for (const string& input : op->inputs) {
+ arrays.push_back(input);
+ }
+ for (const string& output : op->outputs) {
+ arrays.push_back(output);
+ }
+ bool changed = false;
+ for (const string& array : arrays) {
+ changed |= DequantizeArray(array, this, model);
+ }
+
+ return changed;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc
new file mode 100644
index 0000000000..fea360740f
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc
@@ -0,0 +1,56 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool DropFakeQuant::Run(Model* model, std::size_t op_index) {
+ const auto fakequant_it = model->operators.begin() + op_index;
+ auto* fakequant_base_op = fakequant_it->get();
+ if (fakequant_base_op->type != OperatorType::kFakeQuant) {
+ return false;
+ }
+ auto* fakequant_op = static_cast<FakeQuantOperator*>(fakequant_base_op);
+
+ if (!fakequant_op->minmax) {
+ return false;
+ }
+
+ const auto& output_array = model->GetArray(fakequant_op->outputs[0]);
+ if (!output_array.minmax) {
+ return false;
+ }
+
+ // Drop min/max inputs
+ for (int i = 1; i < fakequant_op->inputs.size(); i++) {
+ if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) {
+ model->arrays.erase(fakequant_op->inputs[i]);
+ }
+ }
+ fakequant_op->inputs.resize(1);
+
+ return RemoveTrivialPassthroughOp(this, model, op_index);
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc
new file mode 100644
index 0000000000..a3ed6663bc
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool DropIm2colArrays::Run(Model* model, std::size_t op_index) {
+ auto conv_it = model->operators.begin() + op_index;
+ if (conv_it->get()->type != OperatorType::kConv) {
+ return false;
+ }
+ auto* conv_op = static_cast<ConvOperator*>(conv_it->get());
+ if (conv_op->outputs.size() < 2) {
+ // Conv op does not have im2col.
+ return false;
+ }
+
+ // Drop the im2col array.
+ CHECK_EQ(conv_op->outputs.size(), 2);
+ model->arrays.erase(conv_op->outputs[1]);
+ conv_op->outputs.resize(1);
+ AddMessageF("Dropped an im2col array for %s", LogName(*conv_op));
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc
new file mode 100644
index 0000000000..badefeca88
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc
@@ -0,0 +1,57 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool ProcessLinearOperator(Model* model, Operator* op) {
+ if (op->inputs.size() >= 3) {
+ return false;
+ }
+ const string& output_name = op->outputs[0];
+ const string& bias_name = AvailableArrayName(*model, output_name + "_bias");
+ op->inputs.push_back(bias_name);
+ DCHECK_EQ(op->inputs.size(), 3);
+ auto& bias_array = model->GetOrCreateArray(bias_name);
+ bias_array.data_type = ArrayDataType::kFloat;
+
+ return true;
+}
+} // namespace
+
+bool EnsureBiasVectors::Run(Model* model, std::size_t op_index) {
+ auto* op = model->operators[op_index].get();
+ if (op->type == OperatorType::kConv ||
+ op->type == OperatorType::kDepthwiseConv ||
+ op->type == OperatorType::kFullyConnected) {
+ if (ProcessLinearOperator(model, op)) {
+ AddMessageF("Added bias vector to %s", LogName(*op));
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc
new file mode 100644
index 0000000000..7a86510025
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc
@@ -0,0 +1,98 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
+ const auto ac_it = model->operators.begin() + op_index;
+ const auto* ac_op = ac_it->get();
+
+ if (ac_op->type != OperatorType::kRelu6 &&
+ ac_op->type != OperatorType::kRelu1 &&
+ ac_op->type != OperatorType::kRelu) {
+ return false;
+ }
+
+ // Find the op producing the array passed to this activation function
+ Operator* op = GetOpWithOutput(*model, ac_op->inputs[0]);
+
+ if (!op) return false;
+
+ if (CountTrueOutputs(*model, *op) > 1) {
+ AddMessageF(
+ "Not fusing activation function into %s because it has more than one "
+ " consumed output",
+ LogName(*op));
+ return false;
+ }
+
+ CHECK_EQ(op->outputs[0], ac_op->inputs[0]);
+
+ int count_ops_consuming_output = CountOpsWithInput(*model, ac_op->inputs[0]);
+ DCHECK_GE(count_ops_consuming_output, 1);
+ if (count_ops_consuming_output > 1) {
+ AddMessageF(
+ "Not fusing activation function into %s because it is consumed by more "
+ "than 1 other operator",
+ LogName(*op));
+ return false;
+ }
+
+ if (op->fused_activation_function != FusedActivationFunctionType::kNone) {
+ AddMessageF(
+ "Not fusing activation function into %s because it already has a fused "
+ "activation function",
+ LogName(*op));
+ return false;
+ }
+
+ // TODO(dkalenichenko): Great many ops don't support activation function
+ // fusing. Switch to the whilelist approach instead.
+ if (op->type == OperatorType::kConcatenation ||
+ op->type == OperatorType::kSlice) {
+ AddMessageF(
+ "Not fusing activation function because the %s op doesn't support it",
+ LogName(*op));
+ return false;
+ }
+
+ AddMessageF("Fusing activation function %s into the preceding %s",
+ LogName(*ac_op), LogName(*op));
+ if (ac_op->type == OperatorType::kRelu6) {
+ op->fused_activation_function = FusedActivationFunctionType::kRelu6;
+ } else if (ac_op->type == OperatorType::kRelu1) {
+ op->fused_activation_function = FusedActivationFunctionType::kRelu1;
+ } else if (ac_op->type == OperatorType::kRelu) {
+ op->fused_activation_function = FusedActivationFunctionType::kRelu;
+ } else {
+ LOG(FATAL) << "Unhandled activation function type";
+ }
+ model->arrays.erase(ac_op->inputs[0]);
+ op->outputs[0] = ac_op->outputs[0];
+ model->operators.erase(ac_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
new file mode 100644
index 0000000000..4619d8bbee
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
@@ -0,0 +1,300 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+void FuseAddOrSubParamsIntoFollowingAffine(Model* model, Operator* following_op,
+ const Operator* add_or_sub_op,
+ int index_of_constant_input) {
+ CHECK(add_or_sub_op->type == OperatorType::kAdd ||
+ add_or_sub_op->type == OperatorType::kSub);
+ CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
+ // If the op is a subtraction, the constant input should be the right hand
+ // side.
+ // This should have been checked before this point.
+ CHECK(add_or_sub_op->type != OperatorType::kSub ||
+ index_of_constant_input == 1);
+ if (following_op->inputs.size() < 3) {
+ LOG(FATAL) << "Missing bias parameter";
+ }
+ const auto& weights = model->GetArray(following_op->inputs[1]);
+ auto& bias = model->GetArray(following_op->inputs[2]);
+ bias.minmax = nullptr;
+ const auto& operand =
+ model->GetArray(add_or_sub_op->inputs[index_of_constant_input]);
+ // We're only supporting the case of a scalar operand. Should have
+ // been checked earlier.
+ CHECK_EQ(RequiredBufferSizeForShape(operand.shape()), 1);
+
+ const float scalar_operand =
+ operand.GetBuffer<ArrayDataType::kFloat>().data[0];
+ // At this point we reduce the case of subtraction to that of addition
+ // by negating the operand.
+ float add_scalar_operand = 0.f;
+ if (add_or_sub_op->type == OperatorType::kAdd) {
+ add_scalar_operand = scalar_operand;
+ } else if (add_or_sub_op->type == OperatorType::kSub &&
+ index_of_constant_input == 1) {
+ add_scalar_operand = -scalar_operand;
+ } else {
+ LOG(FATAL) << "Should not get here";
+ }
+ // From here on we are fusing an addition. add_or_sub_op->type does not
+ // matter anymore.
+
+ const Shape& weights_shape = weights.shape();
+ const Shape& bias_shape = bias.shape();
+ const auto& weights_buffer = weights.GetBuffer<ArrayDataType::kFloat>();
+ const float* const weights_data = weights_buffer.data.data();
+ auto& bias_buffer = bias.GetMutableBuffer<ArrayDataType::kFloat>();
+ float* const bias_data = bias_buffer.data.data();
+
+ if (following_op->type == OperatorType::kConv ||
+ following_op->type == OperatorType::kFullyConnected) {
+ const int output_depth = weights_shape.dims(0);
+ // TODO(b/62904716): Bias array should become 1-D when padding removed.
+ CHECK_EQ(output_depth, bias_shape.dims(bias_shape.dimensions_count() - 1));
+ const int weights_size = RequiredBufferSizeForShape(weights_shape);
+ const int weights_per_depth = weights_size / output_depth;
+ CHECK_EQ(weights_size, weights_per_depth * output_depth);
+
+ for (int d = 0; d < output_depth; d++) {
+ float accumulation = 0;
+ for (int i = 0; i < weights_per_depth; i++) {
+ accumulation +=
+ add_scalar_operand * weights_data[d * weights_per_depth + i];
+ }
+ bias_data[d] += accumulation;
+ }
+ } else if (following_op->type == OperatorType::kDepthwiseConv) {
+ const int output_depth =
+ weights_shape.dims(weights_shape.dimensions_count() - 1);
+ const int weights_size = RequiredBufferSizeForShape(weights_shape);
+ const int weights_per_depth = weights_size / output_depth;
+ CHECK_EQ(weights_size, weights_per_depth * output_depth);
+
+ for (int c = 0; c < output_depth; c++) {
+ float accumulation = 0;
+ for (int k = 0; k < weights_per_depth; k++) {
+ accumulation += add_scalar_operand * weights_data[k * output_depth + c];
+ }
+ bias_data[c] += accumulation;
+ }
+ } else {
+ LOG(FATAL) << "Should not get here.";
+ }
+}
+
+void FuseMulOrDivParamsIntoFollowingAffine(Model* model, Operator* following_op,
+ const Operator* mul_or_div_op,
+ int index_of_constant_input) {
+ CHECK(mul_or_div_op->type == OperatorType::kMul ||
+ mul_or_div_op->type == OperatorType::kDiv);
+ CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
+ // If the op is a division, the constant input should be the right hand side.
+ // This should have been checked before this point.
+ CHECK(mul_or_div_op->type != OperatorType::kDiv ||
+ index_of_constant_input == 1);
+ const auto& weights_name = following_op->inputs[1];
+ const auto& bias_name = following_op->inputs[2];
+ auto& weights = model->GetArray(weights_name);
+ DropMinMax(model, weights_name);
+ DropMinMax(model, bias_name);
+ const auto& operand =
+ model->GetArray(mul_or_div_op->inputs[index_of_constant_input]);
+ // We're only supporting the case of a scalar operand. Should have
+ // been checked earlier.
+ CHECK_EQ(RequiredBufferSizeForShape(operand.shape()), 1);
+
+ const float scalar_operand =
+ operand.GetBuffer<ArrayDataType::kFloat>().data[0];
+
+ float* weights_data =
+ weights.GetMutableBuffer<ArrayDataType::kFloat>().data.data();
+ const int weights_size = RequiredBufferSizeForShape(weights.shape());
+ for (int i = 0; i < weights_size; i++) {
+ if (mul_or_div_op->type == OperatorType::kMul) {
+ weights_data[i] *= scalar_operand;
+ } else if (mul_or_div_op->type == OperatorType::kDiv) {
+ weights_data[i] /= scalar_operand;
+ } else {
+ LOG(FATAL) << "Should not get here";
+ }
+ }
+}
+
+} // namespace
+
+bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
+ const auto binary_it = model->operators.begin() + op_index;
+ auto* binary_op = binary_it->get();
+ if (binary_op->type != OperatorType::kAdd &&
+ binary_op->type != OperatorType::kMul &&
+ binary_op->type != OperatorType::kSub &&
+ binary_op->type != OperatorType::kDiv) {
+ return false;
+ }
+
+ CHECK_EQ(binary_op->inputs.size(), 2);
+
+ // We only can fuse an binary when the two operands break down as follows:
+ // 1. One operand is the (variable) output of a typical affine (linear plus
+ // bias)
+ // op of a finite list of possible types: at the moment Conv,
+ // DepthwiseConv and
+ // FullyConnected are supported.
+ // 2. The other operand is a constant param array.
+ const bool is_input_constant[2] = {
+ IsConstantParameterArray(*model, binary_op->inputs[0]),
+ IsConstantParameterArray(*model, binary_op->inputs[1]),
+ };
+ if (!is_input_constant[0] && !is_input_constant[1]) {
+ // Neither input is constant, so nothing we can fuse into a constant.
+ return false;
+ }
+ if (is_input_constant[0] && is_input_constant[1]) {
+ // Both inputs are constants. That's a job for constants
+ // propagation, not for us to handle here.
+ return false;
+ }
+ const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
+ const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
+ CHECK(is_input_constant[index_of_constant_input]);
+ CHECK(!is_input_constant[index_of_variable_input]);
+
+ // For division, we can only fuse if the denominator is constant.
+ if (binary_op->type == OperatorType::kDiv) {
+ if (index_of_constant_input != 1) {
+ AddMessageF("Not fusing %s because the denominator is not constant",
+ LogName(*binary_op));
+ return false;
+ }
+ }
+
+ const auto& operand_shape =
+ model->GetArray(binary_op->inputs[index_of_constant_input]).shape();
+ for (const auto& dim : operand_shape.dims()) {
+ if (dim > 1) {
+ AddMessageF(
+ "Not fusing %s into the following affine op, because we only know "
+ "how to do so when the constant operand is a scalar",
+ LogName(*binary_op));
+ return false;
+ }
+ }
+
+ if (binary_op->fused_activation_function !=
+ FusedActivationFunctionType::kNone) {
+ AddMessageF("Not fusing %s because it has a fused activation function",
+ LogName(*binary_op));
+ return false;
+ }
+
+ Operator* following_op = GetOpWithInput(*model, binary_op->outputs[0]);
+
+ if (!following_op) {
+ AddMessageF(
+ "Not fusing %s because it is not consumed by exactly one other op",
+ LogName(*binary_op));
+ return false;
+ }
+
+ if (following_op->type != OperatorType::kConv &&
+ following_op->type != OperatorType::kFullyConnected &&
+ following_op->type != OperatorType::kDepthwiseConv) {
+ AddMessageF(
+ "Not fusing %s because the following %s is not of one of the supported "
+ "types",
+ LogName(*binary_op), LogName(*following_op));
+ return false;
+ }
+
+ if (following_op->inputs.size() < 3) {
+ AddMessageF(
+ "Not fusing %s because the following %s does not have a bias vector",
+ LogName(*following_op), LogName(*binary_op));
+ return false;
+ }
+
+ const auto& weights = model->GetArray(following_op->inputs[1]);
+ const auto& bias = model->GetArray(following_op->inputs[2]);
+ if (!weights.buffer || !bias.buffer) {
+ AddMessageF(
+ "Not fusing %s because the following %s has non-constant weights or "
+ "bias arrays",
+ LogName(*binary_op), LogName(*following_op));
+ return false;
+ }
+
+ // Try to fuse the binary params into the following op's params
+ if (binary_op->type == OperatorType::kAdd ||
+ binary_op->type == OperatorType::kSub) {
+ if (following_op->type == OperatorType::kConv) {
+ if (static_cast<ConvOperator*>(following_op)->padding.type !=
+ PaddingType::kValid) {
+ AddMessageF(
+ "Not fusing %s because the following %s does not use VALID padding",
+ LogName(*binary_op), LogName(*following_op));
+ return false;
+ }
+ }
+ if (following_op->type == OperatorType::kDepthwiseConv) {
+ if (static_cast<DepthwiseConvOperator*>(following_op)->padding.type !=
+ PaddingType::kValid) {
+ AddMessageF(
+ "Not fusing %s because the following %s does not use VALID padding",
+ LogName(*binary_op), LogName(*following_op));
+ return false;
+ }
+ }
+ FuseAddOrSubParamsIntoFollowingAffine(model, following_op, binary_op,
+ index_of_constant_input);
+ } else if (binary_op->type == OperatorType::kMul ||
+ binary_op->type == OperatorType::kDiv) {
+ FuseMulOrDivParamsIntoFollowingAffine(model, following_op, binary_op,
+ index_of_constant_input);
+ } else {
+ LOG(FATAL) << "should not get here";
+ }
+
+ AddMessageF("Fusing %s into the following %s", LogName(*binary_op),
+ LogName(*following_op));
+
+ model->arrays.erase(binary_op->outputs[0]);
+ following_op->inputs[0] = binary_op->inputs[index_of_variable_input];
+ const auto& old_constant_param_name =
+ binary_op->inputs[index_of_constant_input];
+ CHECK(IsConstantParameterArray(*model, old_constant_param_name));
+ if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
+ model->arrays.erase(old_constant_param_name);
+ }
+ model->operators.erase(binary_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
new file mode 100644
index 0000000000..8948653ec3
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
@@ -0,0 +1,326 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+void FuseAddOrSubParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
+ const Operator* add_or_sub_op,
+ int index_of_constant_input) {
+ CHECK(add_or_sub_op->type == OperatorType::kAdd ||
+ add_or_sub_op->type == OperatorType::kSub);
+ CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
+ if (preceding_op->inputs.size() < 3) {
+ LOG(FATAL) << "Missing bias parameter";
+ }
+ auto& bias = model->GetArray(preceding_op->inputs[2]);
+ bias.minmax = nullptr;
+ const auto& operand =
+ model->GetArray(add_or_sub_op->inputs[index_of_constant_input]);
+
+ const Shape& bias_shape = bias.shape();
+ const Shape& operand_shape = operand.shape();
+ auto& bias_buffer = bias.GetMutableBuffer<ArrayDataType::kFloat>();
+ float* const bias_data = bias_buffer.data.data();
+ const auto& operand_buffer = operand.GetBuffer<ArrayDataType::kFloat>();
+ const float* const operand_data = operand_buffer.data.data();
+
+ // TODO(b/62904716): Bias array should become 1-D when padding removed.
+ const int depth = bias_shape.dims(bias_shape.dimensions_count() - 1);
+ CHECK_EQ(depth, operand_shape.dims(operand_shape.dimensions_count() - 1));
+
+ enum class OpType { BiasPlusOperand, BiasMinusOperand, OperandMinusBias };
+
+ const OpType optype = (add_or_sub_op->type == OperatorType::kAdd)
+ ? OpType::BiasPlusOperand
+ : (index_of_constant_input == 1)
+ ? OpType::BiasMinusOperand
+ : OpType::OperandMinusBias;
+
+ for (int i = 0; i < depth; i++) {
+ float& bias_val = bias_data[i];
+ const float operand_val = operand_data[i];
+ if (optype == OpType::BiasPlusOperand) {
+ bias_val += operand_val;
+ } else if (optype == OpType::BiasMinusOperand) {
+ bias_val -= operand_val;
+ } else if (optype == OpType::OperandMinusBias) {
+ bias_val = operand_val - bias_val;
+ } else {
+ LOG(FATAL) << "Should not get here.";
+ }
+ }
+}
+
+void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
+ const Operator* mul_or_div_op,
+ int index_of_constant_input) {
+ CHECK(mul_or_div_op->type == OperatorType::kMul ||
+ mul_or_div_op->type == OperatorType::kDiv);
+ CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
+ // If the op is a division, the constant input should be the right hand side.
+ // This should have been checked before this point.
+ CHECK(mul_or_div_op->type != OperatorType::kDiv ||
+ index_of_constant_input == 1);
+ if (preceding_op->inputs.size() < 3) {
+ LOG(FATAL) << "Missing bias parameter";
+ }
+ const auto& weights_name = preceding_op->inputs[1];
+ const auto& bias_name = preceding_op->inputs[2];
+ auto& weights = model->GetArray(weights_name);
+ DropMinMax(model, weights_name);
+ auto& bias = model->GetArray(bias_name);
+ DropMinMax(model, bias_name);
+ const auto& operand =
+ model->GetArray(mul_or_div_op->inputs[index_of_constant_input]);
+
+ const Shape& weights_shape = weights.shape();
+ const Shape& bias_shape = bias.shape();
+ const Shape& operand_shape = operand.shape();
+ auto& weights_buffer = weights.GetMutableBuffer<ArrayDataType::kFloat>();
+ float* const weights_data = weights_buffer.data.data();
+ auto& bias_buffer = bias.GetMutableBuffer<ArrayDataType::kFloat>();
+ float* const bias_data = bias_buffer.data.data();
+ const auto& operand_buffer = operand.GetBuffer<ArrayDataType::kFloat>();
+ const float* const operand_data = operand_buffer.data.data();
+
+ // We support broadcasting the operand along the depth dimension,
+ // when the operand's depth is 1.
+ int operand_channel_increment = 0;
+ if (operand_shape.dimensions_count() >= 1 &&
+ operand_shape.dims(operand_shape.dimensions_count() - 1) ==
+ bias_shape.dims(bias_shape.dimensions_count() - 1)) {
+ operand_channel_increment = 1;
+ } else if (operand_shape.dimensions_count() == 0 ||
+ operand_shape.dims(operand_shape.dimensions_count() - 1) == 1) {
+ operand_channel_increment = 0;
+ } else {
+ LOG(FATAL) << "Operand shape mismatch.";
+ }
+
+ int output_depth;
+
+ if (preceding_op->type == OperatorType::kConv ||
+ preceding_op->type == OperatorType::kFullyConnected) {
+ output_depth = weights_shape.dims(0);
+ } else if (preceding_op->type == OperatorType::kDepthwiseConv) {
+ output_depth = weights_shape.dims(weights_shape.dimensions_count() - 1);
+ } else {
+ LOG(FATAL) << "Should not get here";
+ }
+
+ const int weights_size = RequiredBufferSizeForShape(weights_shape);
+ const int weights_per_depth = weights_size / output_depth;
+ CHECK_EQ(weights_size, weights_per_depth * output_depth);
+
+ int operand_channel = 0;
+ for (int c = 0; c < output_depth; c++) {
+ if (mul_or_div_op->type == OperatorType::kMul) {
+ bias_data[c] *= operand_data[operand_channel];
+ } else if (mul_or_div_op->type == OperatorType::kDiv) {
+ bias_data[c] /= operand_data[operand_channel];
+ } else {
+ LOG(FATAL) << "Should not get here";
+ }
+ if (preceding_op->type == OperatorType::kConv ||
+ preceding_op->type == OperatorType::kFullyConnected) {
+ for (int i = 0; i < weights_per_depth; i++) {
+ if (mul_or_div_op->type == OperatorType::kMul) {
+ weights_data[c * weights_per_depth + i] *=
+ operand_data[operand_channel];
+ } else if (mul_or_div_op->type == OperatorType::kDiv) {
+ weights_data[c * weights_per_depth + i] /=
+ operand_data[operand_channel];
+ } else {
+ LOG(FATAL) << "Should not get here";
+ }
+ }
+ } else if (preceding_op->type == OperatorType::kDepthwiseConv) {
+ for (int k = 0; k < weights_per_depth; k++) {
+ if (mul_or_div_op->type == OperatorType::kMul) {
+ weights_data[k * output_depth + c] *= operand_data[operand_channel];
+ } else if (mul_or_div_op->type == OperatorType::kDiv) {
+ weights_data[k * output_depth + c] /= operand_data[operand_channel];
+ } else {
+ LOG(FATAL) << "Should not get here";
+ }
+ }
+ } else {
+ LOG(FATAL) << "Should not get here";
+ }
+ operand_channel += operand_channel_increment;
+ }
+}
+} // namespace
+
+bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
+ const auto binary_it = model->operators.begin() + op_index;
+ const auto* binary_op = binary_it->get();
+ if (binary_op->type != OperatorType::kAdd &&
+ binary_op->type != OperatorType::kMul &&
+ binary_op->type != OperatorType::kSub &&
+ binary_op->type != OperatorType::kDiv) {
+ return false;
+ }
+
+ CHECK_EQ(binary_op->inputs.size(), 2);
+
+ // We only can fuse an binary when the two operands break down as follows:
+ // 1. One operand is the (variable) output of a typical affine (linear plus
+ // bias)
+ // op of a finite list of possible types: at the moment Conv,
+ // DepthwiseConv and
+ // FullyConnected are supported.
+ // 2. The other operand is a constant param array.
+ const bool is_input_constant[2] = {
+ IsConstantParameterArray(*model, binary_op->inputs[0]),
+ IsConstantParameterArray(*model, binary_op->inputs[1]),
+ };
+ if (!is_input_constant[0] && !is_input_constant[1]) {
+ // Neither input is constant, so nothing we can fuse into a constant.
+ return false;
+ }
+ if (is_input_constant[0] && is_input_constant[1]) {
+ // Both inputs are constants. That's a job for constants
+ // propagation, not for us to handle here.
+ return false;
+ }
+ const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
+ const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
+ CHECK(is_input_constant[index_of_constant_input]);
+ CHECK(!is_input_constant[index_of_variable_input]);
+
+ // For division, we can only fuse if the denominator is constant.
+ if (binary_op->type == OperatorType::kDiv) {
+ if (index_of_constant_input != 1) {
+ AddMessageF("Not fusing %s because the denominator is not constant",
+ LogName(*binary_op));
+ return false;
+ }
+ }
+
+ Operator* preceding_op =
+ GetOpWithOutput(*model, binary_op->inputs[index_of_variable_input]);
+ if (!preceding_op) {
+ AddMessageF("Not fusing %s because it is not the output of another op",
+ LogName(*binary_op));
+ return false;
+ }
+
+ for (const string& output_array : model->flags.output_arrays()) {
+ if (preceding_op->outputs[0] == output_array) {
+ return false;
+ }
+ }
+
+ if (preceding_op->type != OperatorType::kConv &&
+ preceding_op->type != OperatorType::kFullyConnected &&
+ preceding_op->type != OperatorType::kDepthwiseConv) {
+ AddMessageF(
+ "Not fusing %s because the preceding %s is not of one of the supported "
+ "types",
+ LogName(*binary_op), LogName(*preceding_op));
+ return false;
+ }
+
+ if (preceding_op->fused_activation_function !=
+ FusedActivationFunctionType::kNone) {
+ AddMessageF(
+ "Not fusing %s because the preceding %s has a fused activation "
+ "function",
+ LogName(*binary_op), LogName(*preceding_op));
+ return false;
+ }
+
+ if (preceding_op->inputs.size() < 3) {
+ AddMessageF(
+ "Not fusing %s because the preceding %s does not have a bias vector",
+ LogName(*binary_op), LogName(*preceding_op));
+ return false;
+ }
+
+ const auto& weights = model->GetArray(preceding_op->inputs[1]);
+ const auto& bias = model->GetArray(preceding_op->inputs[2]);
+ if (binary_op->type == OperatorType::kAdd ||
+ binary_op->type == OperatorType::kSub) {
+ if (!bias.buffer) {
+ AddMessageF(
+ "Not fusing %s because the preceding %s has a non-constant bias "
+ "array",
+ LogName(*binary_op), LogName(*preceding_op));
+ return false;
+ }
+ } else {
+ if (!weights.buffer || !bias.buffer) {
+ AddMessageF(
+ "Not fusing %s because the preceding %s has non-constant weights or "
+ "bias arrays",
+ LogName(*binary_op), LogName(*preceding_op));
+ return false;
+ }
+ }
+
+ int count_ops_consuming_output =
+ CountOpsWithInput(*model, preceding_op->outputs[0]);
+ DCHECK_GE(count_ops_consuming_output, 1);
+ if (count_ops_consuming_output > 1) {
+ AddMessageF(
+ "Not fusing %s because the output of the preceding %s is consumed by "
+ "another op",
+ LogName(*binary_op), LogName(*preceding_op));
+ return false;
+ }
+
+ AddMessageF("Fusing %s into the preceding %s", LogName(*binary_op),
+ LogName(*preceding_op));
+
+ if (binary_op->type == OperatorType::kAdd ||
+ binary_op->type == OperatorType::kSub) {
+ FuseAddOrSubParamsIntoPrecedingAffine(model, preceding_op, binary_op,
+ index_of_constant_input);
+ } else if (binary_op->type == OperatorType::kMul ||
+ binary_op->type == OperatorType::kDiv) {
+ FuseMulOrDivParamsIntoPrecedingAffine(model, preceding_op, binary_op,
+ index_of_constant_input);
+ } else {
+ LOG(FATAL) << "should not get here";
+ }
+
+ model->arrays.erase(preceding_op->outputs[0]);
+ preceding_op->outputs[0] = binary_op->outputs[0];
+ preceding_op->fused_activation_function =
+ binary_op->fused_activation_function;
+ const auto& old_constant_param_name =
+ binary_op->inputs[index_of_constant_input];
+ CHECK(IsConstantParameterArray(*model, old_constant_param_name));
+ if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
+ model->arrays.erase(old_constant_param_name);
+ }
+ model->operators.erase(binary_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc
new file mode 100644
index 0000000000..323fec6cf8
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc
@@ -0,0 +1,108 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+void PrintModelStats(const string& label, const Model& model) {
+ int quantized_arrays = 0;
+ for (const auto& array : model.arrays) {
+ if (array.second->quantization_params) {
+ quantized_arrays++;
+ }
+ }
+ LOG(INFO) << label << ": " << model.operators.size() << " operators, "
+ << model.arrays.size() << " arrays (" << quantized_arrays
+ << " quantized)";
+}
+
+bool GraphTransformationsPass(int increment, Model* model,
+ const GraphTransformationsSet& transformations) {
+ CHECK(increment == 1 || increment == -1);
+ bool changed = false;
+ CHECK(!model->operators.empty());
+ int op_index = increment == 1 ? 0 : model->operators.size() - 1;
+ while (true) {
+ bool changed_now = false;
+ // Loop over all transformations at the current position in the graph.
+ for (const auto& transformation : transformations) {
+ CHECK(!changed_now);
+ CHECK(transformation->Messages().empty());
+ changed_now = transformation->Run(model, op_index);
+ if (changed_now) {
+ DumpGraphvizVideoFrame(*model);
+ CHECK(!model->operators.empty());
+ op_index = std::min<int>(op_index, model->operators.size() - 1);
+ // Uncomment for debugging
+ // CheckInvariants(*model);
+ }
+ const char* made_a_change_msg =
+ changed_now ? "made a change" : "did NOT make a change";
+ const int log_level =
+ changed_now ? kLogLevelModelChanged : kLogLevelModelUnchanged;
+ for (const string& message : transformation->Messages()) {
+ VLOG(log_level) << transformation->Name() << " " << made_a_change_msg
+ << " at op_index=" << op_index << "/"
+ << model->operators.size() - 1 << ": " << message;
+ }
+ transformation->ClearMessages();
+ if (changed_now) {
+ break;
+ }
+ }
+ if (changed_now) {
+ changed = true;
+ } else {
+ const int op_index_last =
+ increment == 1 ? model->operators.size() - 1 : 0;
+ if (op_index == op_index_last) {
+ break;
+ }
+ op_index += increment;
+ }
+ }
+ return changed;
+}
+
+} // namespace
+
+void RunGraphTransformations(Model* model, const string& msg,
+ const GraphTransformationsSet& transformations) {
+ PrintModelStats(toco::port::StringF("Before %s", msg), *model);
+ int pass_index = 0;
+ while (GraphTransformationsPass((pass_index % 2) ? -1 : 1, model,
+ transformations)) {
+ pass_index++;
+ const auto& label =
+ toco::port::StringF("After %s pass %d", msg, pass_index);
+ PrintModelStats(label, *model);
+ CheckInvariants(*model);
+ }
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
new file mode 100644
index 0000000000..2cc24ff361
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -0,0 +1,186 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_
+
+#include <cstddef>
+#include <initializer_list>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+
+namespace toco {
+
+class GraphTransformation {
+ public:
+ virtual bool Run(Model* model, std::size_t op_index) = 0;
+ virtual const char* Name() const = 0;
+ virtual ~GraphTransformation() {}
+ // Returns the list of messages that this graph transformation
+ // generated since ClearMessages() was called.
+ const std::vector<string>& Messages() const { return messages_; }
+ // Clears the list of messages; should be called after every
+ // run of this graph transformation.
+ void ClearMessages() { return messages_.clear(); }
+ // Adds a message; normally only called by the graph transformation
+ // itself during its run (this function could be protected).
+ template <typename... Args>
+ void AddMessageF(const char* format, const Args&... args) {
+ return messages_.push_back(toco::port::StringF(format, args...));
+ }
+
+ protected:
+ GraphTransformation() {}
+
+ // List of messages generated by this graph transformation.
+ std::vector<string> messages_;
+
+ private:
+ GraphTransformation(const GraphTransformation& other) = delete;
+ GraphTransformation(const GraphTransformation&& other) = delete;
+};
+
+class GraphTransformationsSet {
+ public:
+ // The choice of a container with fully-specified iteration order
+ // ensures that graph transformations are always run in the same order,
+ // which avoids having toco randomly fail or produce different results
+ // depending on the toolchain. Ideally success/results should be independent
+ // of the order in which graph transformations are run, but that's
+ // unfortunately not currently guaranteed to be the case.
+ using TransformationsContainer =
+ std::vector<std::unique_ptr<GraphTransformation>>;
+
+ GraphTransformationsSet() {}
+ GraphTransformationsSet(
+ const std::initializer_list<GraphTransformation*> transformations) {
+ for (GraphTransformation* t : transformations) {
+ Add(t);
+ }
+ }
+ void Add(GraphTransformation* transformation) {
+ const string& name = transformation->Name();
+ CHECK(!names_.count(name));
+ names_.insert(name);
+ transformations_.emplace_back(transformation);
+ }
+ TransformationsContainer::const_iterator begin() const {
+ return transformations_.begin();
+ }
+ TransformationsContainer::const_iterator end() const {
+ return transformations_.end();
+ }
+ bool empty() const { return transformations_.empty(); }
+
+ private:
+ GraphTransformationsSet(const GraphTransformationsSet& other) = delete;
+ GraphTransformationsSet(const GraphTransformationsSet&& other) = delete;
+ std::vector<std::unique_ptr<GraphTransformation>> transformations_;
+ // Names of transformations in the set. Only used to guard against dupes.
+ std::unordered_set<string> names_;
+};
+
+// Run the given list of graph transformations on the model.
+// The message is only for logging purposes.
+// The transformations is a rvalue reference, indicating that
+// nothing else will use these pointers. The user is supposed to
+// construct GraphTransformation objects by using 'new', pass us
+// the resulting raw pointers, and this RunGraphTransformations
+// takes care of delete'ing these pointers.
+void RunGraphTransformations(Model* model, const string& message,
+ const GraphTransformationsSet& transformations);
+
+#define DECLARE_GRAPH_TRANSFORMATION(GTName) \
+ class GTName : public GraphTransformation { \
+ public: \
+ bool Run(Model* model, std::size_t op_index) override; \
+ const char* Name() const { return #GTName; } \
+ };
+
+// List of all graph transformations
+DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise)
+DECLARE_GRAPH_TRANSFORMATION(EnsureBiasVectors)
+DECLARE_GRAPH_TRANSFORMATION(FuseActivationFunctions)
+DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoFollowingAffine)
+DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoPrecedingAffine)
+DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Normalization)
+DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Pool)
+DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell)
+DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1)
+DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator)
+DECLARE_GRAPH_TRANSFORMATION(PropagateArrayDataTypes)
+DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes)
+DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax)
+DECLARE_GRAPH_TRANSFORMATION(Quantize)
+DECLARE_GRAPH_TRANSFORMATION(RemoveFinalDequantizeOp)
+DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowAssert)
+DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity)
+DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialBinaryOperator)
+DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenation)
+DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenationInput)
+DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialQuantizedActivationFunc)
+DECLARE_GRAPH_TRANSFORMATION(RemoveUnusedOp)
+DECLARE_GRAPH_TRANSFORMATION(ResolveBatchNormalization)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantBinaryOperator)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantUnaryOperator)
+DECLARE_GRAPH_TRANSFORMATION(CreateIm2colArrays)
+DECLARE_GRAPH_TRANSFORMATION(DropIm2colArrays)
+DECLARE_GRAPH_TRANSFORMATION(ReadFakeQuantMinMax)
+DECLARE_GRAPH_TRANSFORMATION(ResolveReorderAxes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowConcat)
+DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMatMul)
+DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMerge)
+DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSqueeze)
+DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSwitch)
+DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowTile)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFakeQuant)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantConcatenation)
+DECLARE_GRAPH_TRANSFORMATION(DropFakeQuant)
+DECLARE_GRAPH_TRANSFORMATION(UnfuseActivationFunctions)
+DECLARE_GRAPH_TRANSFORMATION(ResolvePadAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveMeanAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTensorFlowShape)
+DECLARE_GRAPH_TRANSFORMATION(Dequantize)
+
+class ResolveReshapeAttributes : public GraphTransformation {
+ public:
+ bool Run(Model* model, std::size_t op_index) override;
+ const char* Name() const override { return "ResolveReshapeAttributes"; }
+};
+
+class RemoveTrivialReshape : public GraphTransformation {
+ public:
+ bool Run(Model* model, std::size_t op_index) override;
+ const char* Name() const override { return "RemoveTrivialReshape"; }
+ bool treat_expand_dims_as_trivial() const {
+ return treat_expand_dims_as_trivial_;
+ }
+ void set_treat_expand_dims_as_trivial(bool val) {
+ treat_expand_dims_as_trivial_ = val;
+ }
+
+ private:
+ bool treat_expand_dims_as_trivial_ = false;
+};
+
+#undef DECLARE_GRAPH_TRANSFORMATION
+
+} // end namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
new file mode 100644
index 0000000000..d44b5dc7b0
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
@@ -0,0 +1,229 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool HardcodeMinMaxForIm2colArray(Model* model, Operator* op) {
+ if (op->outputs.size() != 2) {
+ return false;
+ }
+ auto& im2col_array = model->GetArray(op->outputs[1]);
+ if (im2col_array.minmax) {
+ return false;
+ }
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.minmax) {
+ return false;
+ }
+ const auto& input_minmax = input_array.GetMinMax();
+ CHECK(!im2col_array.minmax);
+ auto& im2col_minmax = im2col_array.GetOrCreateMinMax();
+ im2col_minmax.min = input_minmax.min;
+ im2col_minmax.max = input_minmax.max;
+ return true;
+}
+
+bool HardcodeMinMaxForL2Normalization(Model* model, Operator* op) {
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.minmax) {
+ return false;
+ }
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.minmax) {
+ return false;
+ }
+ const auto& input_minmax = input_array.GetMinMax();
+ CHECK(!output_array.minmax);
+ auto& output_minmax = output_array.GetOrCreateMinMax();
+ output_minmax.min = input_minmax.min >= 0. ? 0. : -1.;
+ output_minmax.max = input_minmax.max <= 0. ? 0. : 1.;
+ return true;
+}
+
+bool HardcodeMinMaxForConcatenation(Model* model, Operator* op) {
+ // Do not early return if the output already has min/max:
+ // we may still need to adjust the inputs min/max.
+ bool has_minmax = false;
+ double overall_min = std::numeric_limits<double>::infinity();
+ double overall_max = -std::numeric_limits<double>::infinity();
+ for (const auto& input : op->inputs) {
+ if (model->GetArray(input).minmax) {
+ has_minmax = true;
+ const auto* minmax = model->GetArray(input).minmax.get();
+ if (minmax) {
+ overall_min = std::min(overall_min, minmax->min);
+ overall_max = std::max(overall_max, minmax->max);
+ }
+ }
+ }
+ auto& output = model->GetArray(op->outputs[0]);
+ if (output.minmax) {
+ has_minmax = true;
+ const auto* minmax = model->GetArray(op->outputs[0]).minmax.get();
+ if (minmax) {
+ overall_min = std::min(overall_min, minmax->min);
+ overall_max = std::max(overall_max, minmax->max);
+ }
+ }
+ if (!has_minmax) {
+ return false;
+ }
+ MinMax overall_minmax;
+ overall_minmax.min = overall_min;
+ overall_minmax.max = overall_max;
+ bool changed = false;
+ for (const auto& input : op->inputs) {
+ auto& array = model->GetArray(input);
+ if (!array.minmax) {
+ changed = true;
+ } else if (!(overall_minmax == array.GetMinMax())) {
+ changed = true;
+ LOG(WARNING)
+ << "Tweaking the MinMax of array " << input << ", which is "
+ << "an input to " << LogName(*op) << ", because we want all inputs "
+ << "and outputs of a Concatenation operator to have the same MinMax "
+ << "so that it can be implemented as a pure byte-copy, no "
+ "arithmetic.";
+ }
+ array.GetOrCreateMinMax() = overall_minmax;
+ }
+ if (!output.minmax) {
+ changed = true;
+ } else if (!(overall_minmax == output.GetMinMax())) {
+ changed = true;
+ LOG(WARNING)
+ << "Tweaking the MinMax of the output array of " << LogName(*op)
+ << ", because we want all inputs "
+ << "and outputs of a Concatenation operator to have the same MinMax "
+ << "so that it can be implemented as a pure byte-copy, no arithmetic.";
+ }
+ output.GetOrCreateMinMax() = overall_minmax;
+
+ return changed;
+}
+
+// The output of average or max pooling is within the same range as its input.
+bool HardcodeMinMaxForAverageOrMaxPool(Model* model, Operator* op) {
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.minmax) {
+ return false;
+ }
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.minmax) {
+ return false;
+ }
+ const auto& input_minmax = input_array.GetMinMax();
+ CHECK(!output_array.minmax);
+ auto& output_minmax = output_array.GetOrCreateMinMax();
+ output_minmax.min = std::min(input_minmax.min, 0.);
+ output_minmax.max = std::max(input_minmax.max, 0.);
+ return true;
+}
+
+bool HardcodeMinMaxForReshape(Model* model, Operator* op) {
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.minmax) {
+ return false;
+ }
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.minmax) {
+ return false;
+ }
+ const auto& input_minmax = input_array.GetMinMax();
+ CHECK(!output_array.minmax);
+ auto& output_minmax = output_array.GetOrCreateMinMax();
+ output_minmax.min = input_minmax.min;
+ output_minmax.max = input_minmax.max;
+ return true;
+}
+
+bool HardcodeMinMaxForOutput(Model* model, Operator* op, double min,
+ double max) {
+ CHECK_EQ(op->outputs.size(), 1);
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.minmax) {
+ return false;
+ }
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.minmax) {
+ return false;
+ }
+ CHECK(!output_array.minmax);
+ auto& output_minmax = output_array.GetOrCreateMinMax();
+ output_minmax.min = min;
+ output_minmax.max = max;
+ return true;
+}
+} // namespace
+
+bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
+ auto it = model->operators.begin() + op_index;
+ auto* op = it->get();
+ bool changed = false;
+ switch (op->type) {
+ case OperatorType::kConv:
+ changed = HardcodeMinMaxForIm2colArray(model, op);
+ break;
+
+ case OperatorType::kL2Normalization:
+ changed = HardcodeMinMaxForL2Normalization(model, op);
+ break;
+
+ case OperatorType::kConcatenation:
+ changed = HardcodeMinMaxForConcatenation(model, op);
+ break;
+
+ case OperatorType::kAveragePool:
+ case OperatorType::kMaxPool:
+ changed = HardcodeMinMaxForAverageOrMaxPool(model, op);
+ break;
+
+ case OperatorType::kTensorFlowReshape:
+ changed = HardcodeMinMaxForReshape(model, op);
+ break;
+
+ case OperatorType::kLogistic:
+ // We hardcode quantization_params to: zero_point=0, scale=1/256.
+ // This choice of minmax is the one that is equivalent to that.
+ changed = HardcodeMinMaxForOutput(model, op, 0, 255. / 256.);
+ break;
+
+ case OperatorType::kSoftmax:
+ // We hardcode quantization_params to: zero_point=0, scale=1/256.
+ // This choice of minmax is the one that is equivalent to that.
+ changed = HardcodeMinMaxForOutput(model, op, 0, 255. / 256.);
+ break;
+
+ default:
+ break;
+ }
+ if (changed) {
+ AddMessageF("Hardcoded min-max through %s", LogName(*op));
+ }
+ return changed;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
new file mode 100644
index 0000000000..01b75e37c6
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
@@ -0,0 +1,170 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cmath>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
+ Model* model, const Operator* op) {
+ auto it = model->operators.begin();
+ for (; it != model->operators.end(); ++it) {
+ if (it->get() == op) {
+ break;
+ }
+ }
+ return it;
+}
+} // namespace
+
+bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
+ const auto div_it = model->operators.begin() + op_index;
+ const auto* div_or_mul_op = div_it->get();
+ OperatorType expected_op_type_producing_div_or_mul_input;
+ if (div_or_mul_op->type == OperatorType::kDiv) {
+ expected_op_type_producing_div_or_mul_input = OperatorType::kTensorFlowSqrt;
+ } else if (div_or_mul_op->type == OperatorType::kMul) {
+ expected_op_type_producing_div_or_mul_input =
+ OperatorType::kTensorFlowRsqrt;
+ } else {
+ return false;
+ }
+ CHECK_EQ(div_or_mul_op->inputs.size(), 2);
+ Operator* op_producing_div_or_mul_input[2] = {
+ GetOpWithOutput(*model, div_or_mul_op->inputs[0]),
+ GetOpWithOutput(*model, div_or_mul_op->inputs[1]),
+ };
+ if (!op_producing_div_or_mul_input[1] ||
+ op_producing_div_or_mul_input[1]->type !=
+ expected_op_type_producing_div_or_mul_input) {
+ return false;
+ }
+ Operator* sqrt_or_rsqrt_op = op_producing_div_or_mul_input[1];
+ CHECK_EQ(sqrt_or_rsqrt_op->inputs.size(), 1);
+ Operator* op_producing_sqrt_or_rsqrt_input =
+ GetOpWithOutput(*model, sqrt_or_rsqrt_op->inputs[0]);
+ if (!op_producing_sqrt_or_rsqrt_input) {
+ return false;
+ }
+
+ // There may be an Add or a Maximum here, adding or clamping to a "small"
+ // constant scalar.
+ // Reported bug: b/29395854
+ Operator* add_op = nullptr;
+ Operator* op_producing_add_input = nullptr;
+ if (op_producing_sqrt_or_rsqrt_input->type == OperatorType::kAdd ||
+ op_producing_sqrt_or_rsqrt_input->type ==
+ OperatorType::kTensorFlowMaximum) {
+ add_op = op_producing_sqrt_or_rsqrt_input;
+ bool add_can_be_removed = false;
+ CHECK_EQ(op_producing_sqrt_or_rsqrt_input->inputs.size(), 2);
+ for (int i = 0; i < 2; i++) {
+ const auto& input_array =
+ model->GetArray(op_producing_sqrt_or_rsqrt_input->inputs[i]);
+ if (!input_array.buffer) {
+ continue;
+ }
+ if (input_array.buffer->type != ArrayDataType::kFloat) {
+ continue;
+ }
+ if (RequiredBufferSizeForShape(input_array.shape()) != 1) {
+ continue;
+ }
+ const auto& input_float_data =
+ input_array.GetBuffer<ArrayDataType::kFloat>().data;
+ if (std::abs(input_float_data[0]) > 1e-3f) {
+ continue;
+ }
+ add_can_be_removed = true;
+ op_producing_add_input = GetOpWithOutput(*model, add_op->inputs[1 - i]);
+ break;
+ }
+ if (!add_can_be_removed) {
+ AddMessageF(
+ "Giving up trying to identify L2Normalization subgraph "
+ " because the operator producing the input to the square root, %s,"
+ ", does not match the expected pattern",
+ LogName(*op_producing_sqrt_or_rsqrt_input));
+ return false;
+ }
+ }
+
+ Operator* sum_op =
+ add_op ? op_producing_add_input : op_producing_sqrt_or_rsqrt_input;
+ if (sum_op->type != OperatorType::kTensorFlowSum) {
+ AddMessageF(
+ "Giving up trying to identify L2Normalization subgraph: "
+ "expected Sum op, got %s",
+ LogName(*sum_op));
+ return false;
+ }
+
+ Operator* square_op = GetOpWithOutput(*model, sum_op->inputs[0]);
+ if (square_op->type != OperatorType::kTensorFlowSquare) {
+ AddMessageF(
+ "Giving up trying to identify L2Normalization subgraph: "
+ "expected Square op, got %s",
+ LogName(*square_op));
+ return false;
+ }
+
+ CHECK_EQ(square_op->inputs.size(), 1);
+
+ if (square_op->inputs[0] != div_or_mul_op->inputs[0]) {
+ AddMessageF(
+ "Giving up trying to identify L2Normalization subgraph: %s does not "
+ "take the same input as the Mul/Div node",
+ LogName(*square_op));
+ return false;
+ }
+
+ // Create and emplace the new L2Normalization
+ auto* l2norm_op = new L2NormalizationOperator;
+ l2norm_op->inputs = {div_or_mul_op->inputs[0]};
+ l2norm_op->outputs = div_or_mul_op->outputs;
+ model->operators.emplace(div_it, l2norm_op);
+
+ AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2norm_op));
+
+ // Erase the subgraph that is now replaced by L2Normalization
+ model->operators.erase(FindOperator(model, square_op));
+ model->arrays.erase(sum_op->inputs[0]);
+ if (sum_op->inputs.size() > 1) {
+ model->arrays.erase(sum_op->inputs[1]);
+ }
+ model->operators.erase(FindOperator(model, sum_op));
+ if (add_op) {
+ model->arrays.erase(add_op->inputs[0]);
+ model->arrays.erase(add_op->inputs[1]);
+ model->operators.erase(FindOperator(model, add_op));
+ }
+ model->arrays.erase(sqrt_or_rsqrt_op->inputs[0]);
+ model->operators.erase(FindOperator(model, sqrt_or_rsqrt_op));
+ model->arrays.erase(div_or_mul_op->inputs[1]);
+ model->operators.erase(FindOperator(model, div_or_mul_op));
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc
new file mode 100644
index 0000000000..1865416fc2
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc
@@ -0,0 +1,106 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
+ Model* model, const Operator* op) {
+ auto it = model->operators.begin();
+ for (; it != model->operators.end(); ++it) {
+ if (it->get() == op) {
+ break;
+ }
+ }
+ return it;
+}
+} // namespace
+
+bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
+ const auto sqrt_it = model->operators.begin() + op_index;
+ const auto* sqrt_op = sqrt_it->get();
+ if (sqrt_op->type != OperatorType::kTensorFlowSqrt) {
+ return false;
+ }
+
+ CHECK_EQ(sqrt_op->inputs.size(), 1);
+ CHECK_EQ(sqrt_op->outputs.size(), 1);
+
+ const AveragePoolOperator* avpool_op;
+ const Operator* square_op;
+
+ Operator* prev_to_sqrt_op = GetOpWithOutput(*model, sqrt_op->inputs[0]);
+ if (prev_to_sqrt_op->type != OperatorType::kAveragePool) {
+ AddMessageF(
+ "Giving up trying to identify L2Pool subgraph: "
+ "expected AveragePool op, got %s",
+ LogName(*prev_to_sqrt_op));
+ return false;
+ }
+
+ avpool_op = static_cast<const AveragePoolOperator*>(prev_to_sqrt_op);
+ CHECK_EQ(avpool_op->inputs.size(), 1);
+
+ square_op = GetOpWithOutput(*model, avpool_op->inputs[0]);
+ CHECK_EQ(square_op->inputs.size(), 1);
+ if (square_op->type != OperatorType::kTensorFlowSquare) {
+ AddMessageF(
+ "Giving up trying to identify L2Pool subgraph: "
+ "expected Square op, got %s",
+ LogName(*square_op));
+ return false;
+ }
+
+ // Create and emplace L2Pool node.
+ auto* l2pool_op = new L2PoolOperator;
+
+ l2pool_op->inputs = {square_op->inputs[0]};
+ l2pool_op->outputs = sqrt_op->outputs;
+
+ l2pool_op->padding.type = avpool_op->padding.type;
+ // Note that we do not setup avpool_op->padding.fixed here. This is done by
+ // the PropagateFixedSizes graph transformation.
+
+ l2pool_op->stride_height = avpool_op->stride_height;
+ l2pool_op->stride_width = avpool_op->stride_width;
+ l2pool_op->kheight = avpool_op->kheight;
+ l2pool_op->kwidth = avpool_op->kwidth;
+ model->operators.emplace(sqrt_it, l2pool_op);
+
+ AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2pool_op));
+
+ // Erase intermediate arrays, keeping input to square op.
+ model->arrays.erase(avpool_op->inputs[0]);
+ model->arrays.erase(sqrt_op->inputs[0]);
+
+ // Erase three operators being replaced.
+ model->operators.erase(FindOperator(model, square_op));
+ model->operators.erase(FindOperator(model, avpool_op));
+ model->operators.erase(FindOperator(model, sqrt_op));
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
new file mode 100644
index 0000000000..082820fddc
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
@@ -0,0 +1,396 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+namespace toco {
+
+namespace {
+
+std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
+ Model* model, const Operator& op) {
+ auto it = model->operators.begin();
+ for (; it != model->operators.end(); ++it) {
+ if (it->get() == &op) {
+ break;
+ }
+ }
+ return it;
+}
+
+bool GetStateArrayForBackEdge(const Model& model,
+ const string& back_edge_source_array,
+ string* state_array = nullptr) {
+ for (const auto& rnn_state : model.flags.rnn_states()) {
+ if (back_edge_source_array == rnn_state.back_edge_source_array()) {
+ // Found LSTM cell output
+ if (state_array) {
+ *state_array = rnn_state.state_array();
+ }
+ return true;
+ }
+ }
+ return false;
+}
+
+// Returns true if the given operator has exactly 1 input, and is connected to
+// the given op_type.
+// We use kNone to indicate an input unattached to an operator output. Usually
+// these are the static input arrays.
+bool MatchOperatorInputs(const Operator& op, const Model& model,
+ OperatorType op_type, Operator** connected_op) {
+ // Check for required number of inputs
+ if (op.inputs.size() != 1) {
+ return false;
+ }
+
+ // Check if first input is disconnected/connected to an operator
+ Operator* x = GetOpWithOutput(model, op.inputs[0]);
+ if ((op_type == OperatorType::kNone) && (x != nullptr)) {
+ return false;
+ }
+ if ((op_type != OperatorType::kNone) && (x == nullptr)) {
+ return false;
+ }
+
+ // Check that first operator, if connected, is of correct type
+ if ((x != nullptr) && (x->type != op_type)) {
+ return false;
+ }
+
+ // Successfully matched. Optionally return matching input operators.
+ if (connected_op) {
+ *connected_op = x;
+ }
+
+ return true;
+}
+
+// Returns true if the given operator has exactly 2 inputs, which are connected
+// to the given op_types.
+// We use kNone to indicate an input unattached to an operator output. Usually
+// these are the static input arrays.
+bool MatchOperatorInputs(const Operator& op, const Model& model,
+ OperatorType a_op_type, Operator** a_op,
+ OperatorType b_op_type, Operator** b_op) {
+ // Check for required number of inputs
+ if (op.inputs.size() != 2) {
+ return false;
+ }
+
+ // Check if first input is disconnected/connected to an operator
+ Operator* x = GetOpWithOutput(model, op.inputs[0]);
+ if ((a_op_type == OperatorType::kNone) && (x != nullptr)) {
+ return false;
+ }
+ if ((a_op_type != OperatorType::kNone) && (x == nullptr)) {
+ return false;
+ }
+
+ // Check that first operator, if connected, is of correct type
+ if ((x != nullptr) && (x->type != a_op_type)) {
+ return false;
+ }
+
+ // Check if second input is disconnected/connected to an operator
+ Operator* y = GetOpWithOutput(model, op.inputs[1]);
+ if ((b_op_type == OperatorType::kNone) && (y != nullptr)) {
+ return false;
+ }
+ if ((b_op_type != OperatorType::kNone) && (y == nullptr)) {
+ return false;
+ }
+
+ // Check that second operator, if connected, is of correct type
+ if ((y != nullptr) && (y->type != b_op_type)) {
+ return false;
+ }
+
+ // Successfully matched. Optionally return matching input operators.
+ if (a_op != nullptr) {
+ *a_op = x;
+ }
+ if (b_op != nullptr) {
+ *b_op = y;
+ }
+ return true;
+}
+
+// Returns true if the given operator has exactly 3 inputs, which are connected
+// to the given op_types.
+// We use kNone to indicate an input unattached to an operator output. Usually
+// these are the static input arrays.
+bool MatchOperatorInputs(const Operator& op, const Model& model,
+ OperatorType a_op_type, Operator** a_op,
+ OperatorType b_op_type, Operator** b_op,
+ OperatorType c_op_type, Operator** c_op) {
+ // Check for required number of inputs
+ if (op.inputs.size() != 3) {
+ return false;
+ }
+
+ // Check if first input is disconnected/connected to an operator
+ Operator* x = GetOpWithOutput(model, op.inputs[0]);
+ if ((a_op_type == OperatorType::kNone) && (x != nullptr)) {
+ return false;
+ }
+ if ((a_op_type != OperatorType::kNone) && (x == nullptr)) {
+ return false;
+ }
+
+ // Check that first operator, if connected, is of correct type
+ if ((x != nullptr) && (x->type != a_op_type)) {
+ return false;
+ }
+
+ // Check if second input is disconnected/connected to an operator
+ Operator* y = GetOpWithOutput(model, op.inputs[1]);
+ if ((b_op_type == OperatorType::kNone) && (y != nullptr)) {
+ return false;
+ }
+ if ((b_op_type != OperatorType::kNone) && (y == nullptr)) {
+ return false;
+ }
+
+ // Check that second operator, if connected, is of correct type
+ if ((y != nullptr) && (y->type != b_op_type)) {
+ return false;
+ }
+
+ // Check if third input is disconnected/connected to an operator
+ Operator* z = GetOpWithOutput(model, op.inputs[2]);
+ if ((c_op_type == OperatorType::kNone) && (z != nullptr)) {
+ return false;
+ }
+ if ((c_op_type != OperatorType::kNone) && (z == nullptr)) {
+ return false;
+ }
+
+ // Check that third operator, if connected, is of correct type
+ if ((z != nullptr) && (z->type != c_op_type)) {
+ return false;
+ }
+
+ // Successfully matched. Optionally return matching input operators.
+ if (a_op != nullptr) {
+ *a_op = x;
+ }
+ if (b_op != nullptr) {
+ *b_op = y;
+ }
+ if (c_op != nullptr) {
+ *c_op = z;
+ }
+ return true;
+}
+
+absl::string_view FindLongestCommonPrefix(absl::string_view a,
+ absl::string_view b) {
+ if (a.empty() || b.empty()) return absl::string_view();
+
+ const char* pa = a.data();
+ const char* pb = b.data();
+ size_t count = 0;
+ const ssize_t limit = std::min(a.size(), b.size());
+ while (count < limit && *pa == *pb) {
+ ++pa;
+ ++pb;
+ ++count;
+ }
+
+ return absl::string_view(a.data(), count);
+}
+
+} // namespace
+
+bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
+ // This LSTM cell identification method is not invariant to commutation of
+ // commutative operator inputs. For example, if input[0] and input[1] of the
+ // final output multiplication were swapped, this method would not identify it
+ // as an LSTM cell. This is OK in most cases, because
+ // tf.rnn.contrib.BasicLSTMCell always generates LSTM cells the same way.
+
+ // Final output multiply
+ auto op_it = model->operators.begin() + op_index;
+ Operator* final_output_mul = op_it->get();
+ if (final_output_mul->type != OperatorType::kMul) {
+ return false;
+ }
+ Operator *state_output_tanh, *fc_output_sig;
+ if (!MatchOperatorInputs(*final_output_mul, *model, OperatorType::kTanh,
+ &state_output_tanh, OperatorType::kLogistic,
+ &fc_output_sig)) {
+ return false;
+ }
+
+ // State output TanH
+ // (We don't count an operator as ID'd until we verify it has the correct
+ // operator types feeding into it.)
+ Operator* state_combine_add;
+ if (!MatchOperatorInputs(*state_output_tanh, *model, OperatorType::kAdd,
+ &state_combine_add)) {
+ return false;
+ }
+ string prev_state;
+ if (!GetStateArrayForBackEdge(*model, state_output_tanh->inputs[0],
+ &prev_state)) {
+ return false;
+ }
+
+ // State forget & remember addition
+ Operator *state_forget_mul, *state_remember_mul;
+ if (!MatchOperatorInputs(*state_combine_add, *model, OperatorType::kMul,
+ &state_forget_mul, OperatorType::kMul,
+ &state_remember_mul)) {
+ return false;
+ }
+ if (state_forget_mul->inputs[0] != prev_state) {
+ return false;
+ }
+
+ // State forget gate
+ Operator* state_forget_sig;
+ if (!MatchOperatorInputs(*state_forget_mul, *model, OperatorType::kNone,
+ nullptr, OperatorType::kLogistic,
+ &state_forget_sig)) {
+ return false;
+ }
+
+ // State remember gate
+ Operator *state_remember_sig, *state_info_tanh;
+ if (!MatchOperatorInputs(*state_remember_mul, *model, OperatorType::kLogistic,
+ &state_remember_sig, OperatorType::kTanh,
+ &state_info_tanh)) {
+ return false;
+ }
+
+ // State remember "information" activation function
+ Operator* fc_output_split;
+ if (!MatchOperatorInputs(*state_info_tanh, *model,
+ OperatorType::kTensorFlowSplit, &fc_output_split)) {
+ return false;
+ }
+ // State remember gate activation function
+ Operator* tmp;
+ if (!MatchOperatorInputs(*state_remember_sig, *model,
+ OperatorType::kTensorFlowSplit, &tmp) ||
+ (tmp != fc_output_split)) {
+ return false;
+ }
+ // State forget gate activation function
+ if (!MatchOperatorInputs(*state_forget_sig, *model,
+ OperatorType::kTensorFlowSplit, &tmp) ||
+ (tmp != fc_output_split)) {
+ return false;
+ }
+ // Fully connected output activation function
+ if (!MatchOperatorInputs(*fc_output_sig, *model,
+ OperatorType::kTensorFlowSplit, &tmp) ||
+ (tmp != fc_output_split)) {
+ return false;
+ }
+ // Fully connected output split
+ Operator* fully_connected;
+ if (!MatchOperatorInputs(*fc_output_split, *model, OperatorType::kNone,
+ nullptr, OperatorType::kFullyConnected,
+ &fully_connected)) {
+ return false;
+ }
+
+ // Fully connected op
+ Operator* concat_inputs;
+ if (!MatchOperatorInputs(*fully_connected, *model,
+ OperatorType::kConcatenation, &concat_inputs,
+ OperatorType::kNone, nullptr, OperatorType::kNone,
+ nullptr)) {
+ return false;
+ }
+
+ // Emplace a new LSTM cell operator
+ auto* lstm_cell_op = new LstmCellOperator;
+ lstm_cell_op->inputs.resize(LstmCellOperator::NUM_INPUTS);
+ lstm_cell_op->inputs[LstmCellOperator::DATA_INPUT] = concat_inputs->inputs[0];
+ lstm_cell_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT] =
+ concat_inputs->inputs[1];
+ lstm_cell_op->inputs[LstmCellOperator::WEIGHTS_INPUT] =
+ fully_connected->inputs[1];
+ lstm_cell_op->inputs[LstmCellOperator::BIASES_INPUT] =
+ fully_connected->inputs[2];
+ lstm_cell_op->inputs[LstmCellOperator::PREV_STATE_INPUT] = prev_state;
+ lstm_cell_op->outputs.resize(LstmCellOperator::NUM_OUTPUTS);
+ lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT] =
+ state_output_tanh->inputs[0];
+ lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT] =
+ final_output_mul->outputs[0];
+ model->operators.emplace(op_it, lstm_cell_op);
+ AddMessageF("Creating %s replacing equivalent subgraph",
+ LogName(*lstm_cell_op));
+
+ // Create temp arrays used internally during runtime.
+ const string base_name(FindLongestCommonPrefix(
+ lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT],
+ lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT]));
+ const string& concat_temp_array_name =
+ AvailableArrayName(*model, base_name + "concat_temp");
+ model->GetOrCreateArray(concat_temp_array_name);
+ lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] = concat_temp_array_name;
+ const string& activ_temp_array_name =
+ AvailableArrayName(*model, base_name + "activ_temp");
+ model->GetOrCreateArray(activ_temp_array_name);
+ lstm_cell_op->outputs[LstmCellOperator::ACTIV_TEMP] = activ_temp_array_name;
+ AddMessageF("Created temp outputs %s and %s on operator %s",
+ concat_temp_array_name, activ_temp_array_name,
+ LogName(*lstm_cell_op));
+
+ // Delete arrays and operators replaced by the LSTM cell operator. Order is
+ // important - DeleteArrayIfUnused() only succeeds if dependent operators
+ // have been removed first. Start at the output and work towards the input.
+ model->operators.erase(FindOperator(model, *final_output_mul));
+ DeleteArrayIfUnused(state_output_tanh->outputs[0], model);
+ DeleteArrayIfUnused(fc_output_sig->outputs[0], model);
+ model->operators.erase(FindOperator(model, *state_output_tanh));
+ model->operators.erase(FindOperator(model, *fc_output_sig));
+ model->operators.erase(FindOperator(model, *state_combine_add));
+ DeleteArrayIfUnused(state_forget_mul->outputs[0], model);
+ DeleteArrayIfUnused(state_remember_mul->outputs[0], model);
+ model->operators.erase(FindOperator(model, *state_forget_mul));
+ model->operators.erase(FindOperator(model, *state_remember_mul));
+ DeleteArrayIfUnused(state_forget_sig->outputs[0], model);
+ DeleteArrayIfUnused(state_info_tanh->outputs[0], model);
+ DeleteArrayIfUnused(state_remember_sig->outputs[0], model);
+ model->operators.erase(FindOperator(model, *state_forget_sig));
+ model->operators.erase(FindOperator(model, *state_info_tanh));
+ model->operators.erase(FindOperator(model, *state_remember_sig));
+ DeleteArrayIfUnused(fc_output_split->outputs[0], model);
+ DeleteArrayIfUnused(fc_output_split->outputs[1], model);
+ DeleteArrayIfUnused(fc_output_split->outputs[2], model);
+ DeleteArrayIfUnused(fc_output_split->outputs[3], model);
+ string dims_array = fc_output_split->inputs[0];
+ model->operators.erase(FindOperator(model, *fc_output_split));
+ DeleteArrayIfUnused(dims_array, model);
+ DeleteArrayIfUnused(fully_connected->outputs[0], model);
+ model->operators.erase(FindOperator(model, *fully_connected));
+ DeleteArrayIfUnused(concat_inputs->outputs[0], model);
+ model->operators.erase(FindOperator(model, *concat_inputs));
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
new file mode 100644
index 0000000000..cfc77024e7
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
@@ -0,0 +1,103 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
+ Model* model, const Operator* op) {
+ auto it = model->operators.begin();
+ for (; it != model->operators.end(); ++it) {
+ if (it->get() == op) {
+ break;
+ }
+ }
+ return it;
+}
+
+bool CheckArrayIsScalarFloat(Model* model, const std::string& name, float val) {
+ const auto& op_array = model->GetArray(name);
+ if (!op_array.buffer || op_array.buffer->type != ArrayDataType::kFloat ||
+ RequiredBufferSizeForShape(op_array.shape()) != 1) {
+ return false;
+ }
+ const auto& op_data = op_array.GetBuffer<ArrayDataType::kFloat>().data;
+ return op_data[0] == val;
+}
+
+// Returns index of scalar input when there is exactly one scalar, -1 otherwise
+int GetSingleScalarInputIndexOfBinaryOp(Model* model, const Operator* op,
+ float val) {
+ bool input0_is_scalar = CheckArrayIsScalarFloat(model, op->inputs[0], val);
+ bool input1_is_scalar = CheckArrayIsScalarFloat(model, op->inputs[1], val);
+ return input0_is_scalar == input1_is_scalar ? -1 : input0_is_scalar ? 0 : 1;
+}
+} // namespace
+
+bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
+ const auto maximum_it = model->operators.begin() + op_index;
+ const auto* maximum_op = maximum_it->get();
+ if (maximum_op->type != OperatorType::kTensorFlowMaximum) {
+ return false;
+ }
+ CHECK_EQ(maximum_op->inputs.size(), 2);
+ if (maximum_op->outputs.size() != 1) {
+ return false;
+ }
+ int scalar_input_index =
+ GetSingleScalarInputIndexOfBinaryOp(model, maximum_op, -1.0f);
+ if (scalar_input_index == -1) {
+ return false;
+ }
+ const auto* minimum_op = GetOpWithInput(*model, maximum_op->outputs[0]);
+ if (!minimum_op || minimum_op->type != OperatorType::kTensorFlowMinimum) {
+ return false;
+ }
+ if (GetSingleScalarInputIndexOfBinaryOp(model, minimum_op, 1.0f) == -1) {
+ return false;
+ }
+ CHECK_EQ(minimum_op->inputs.size(), 2);
+
+ // Create and emplace Relu1 node
+ auto* relu1_op = new Relu1Operator;
+ relu1_op->inputs = {maximum_op->inputs[!scalar_input_index]};
+ relu1_op->outputs = minimum_op->outputs;
+ model->operators.emplace(maximum_it, relu1_op);
+
+ AddMessageF("Creating %s replacing equivalent subgraph", LogName(*relu1_op));
+
+ // Erase Maximum scalar input & operator
+ model->arrays.erase(maximum_op->inputs[scalar_input_index]);
+ model->operators.erase(FindOperator(model, maximum_op));
+
+ // Erase Minimum inputs & operator
+ model->arrays.erase(minimum_op->inputs[0]);
+ model->arrays.erase(minimum_op->inputs[1]);
+ model->operators.erase(FindOperator(model, minimum_op));
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
new file mode 100644
index 0000000000..d83603e9a2
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
@@ -0,0 +1,120 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+// This inserts an operator whose output is a float array (name:
+// flags.input_array()). It has to wait for any existing operators that
+// generate this output to be removed by graph transformations. Note that there
+// may be more than one operator that takes the input_array as their input, and
+// that some of these may be removed by graph transformations.
+bool AddDequantizeOperatorToInput(const string& input_name, const Operator* op,
+ GraphTransformation* transformation,
+ Model* model) {
+ // An operator with the required output may be a dequantize operator already
+ // created. Alternatively it may be an operator that needs to be removed
+ // because it is unused, in which case we wait for RemoveUnusedOp to do its
+ // work.
+ if (GetOpWithOutput(*model, input_name)) {
+ return false;
+ }
+
+ // We only apply for the first operator if there is more than one. This is
+ // not strictly necessary for ordering correctness, since we insert the
+ // dequant operator at the beginning of the op sequence, but it makes the
+ // insertion more predictable (eg forward vs backwards operator sweep).
+ if (CountOpsWithInput(*model, input_name) > 1) {
+ if (op != GetFirstOpWithInput(*model, input_name)) {
+ return false;
+ }
+ }
+
+ auto& input_array = model->GetArray(input_name);
+ if (input_array.data_type != ArrayDataType::kFloat) {
+ return false;
+ }
+
+ if (input_array.final_data_type == input_array.data_type ||
+ input_array.final_data_type == ArrayDataType::kNone) {
+ return false;
+ }
+
+ const auto& dequantized_input_name =
+ AvailableArrayName(*model, input_name + "_dequantized");
+ for (auto& other_op : model->operators) {
+ for (string& other_op_input : other_op->inputs) {
+ if (other_op_input == input_name) {
+ other_op_input = dequantized_input_name;
+ }
+ }
+ }
+
+ auto& dequantized_input_array =
+ model->GetOrCreateArray(dequantized_input_name);
+ auto* image_input_op = new DequantizeOperator;
+ image_input_op->inputs = {input_name};
+ image_input_op->outputs = {dequantized_input_name};
+ model->operators.emplace(model->operators.begin(), image_input_op);
+
+ CHECK(input_array.final_data_type == ArrayDataType::kUint8);
+ input_array.data_type = ArrayDataType::kUint8;
+ dequantized_input_array.data_type = ArrayDataType::kFloat;
+ const auto& input_minmax = input_array.GetMinMax();
+ auto& dequantized_input_minmax = dequantized_input_array.GetOrCreateMinMax();
+ dequantized_input_minmax = input_minmax;
+ auto& input_qparams = input_array.GetOrCreateQuantizationParams();
+ GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(
+ model->flags, input_minmax, &input_qparams);
+
+ transformation->AddMessageF(
+ "Created %s"
+ " to handle quantized input image data, taking over existing"
+ " mean_value and std_value flags. Cleared those flags.",
+ LogName(*image_input_op));
+
+ return true;
+}
+
+bool MakeInitialDequantizeOperator::Run(Model* model, std::size_t op_index) {
+ // This is effectively a transformation applied to edges. We iterate over the
+ // specified node (op) and proceed for input edges.
+ const auto it = model->operators.begin() + op_index;
+ const auto* op = it->get();
+ bool change_made = false;
+ for (auto& input : op->inputs) {
+ for (auto& input_array : *model->flags.mutable_input_arrays()) {
+ if (input_array.name() == input) {
+ if (AddDequantizeOperatorToInput(input_array.name(), op, this, model)) {
+ change_made = true;
+ input_array.clear_mean_value();
+ input_array.clear_std_value();
+ }
+ }
+ }
+ }
+ return change_made;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
new file mode 100644
index 0000000000..1ff4e827aa
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -0,0 +1,142 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+ArrayDataType CommonDataTypeOfAllInputs(const Model& model,
+ const Operator& op) {
+ CHECK_GT(op.inputs.size(), 0);
+ const ArrayDataType data_type = model.GetArray(op.inputs[0]).data_type;
+ for (const auto& input : op.inputs) {
+ const auto& array = model.GetArray(input);
+ CHECK(array.data_type == data_type)
+ << " Unexpected: this operator has inputs with different data types.";
+ }
+ return data_type;
+}
+
+void SetDataTypeForAllOutputs(Model* model, Operator* op,
+ ArrayDataType data_type) {
+ for (const auto& output : op->outputs) {
+ model->arrays[output]->data_type = data_type;
+ }
+}
+} // namespace
+
+bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
+ auto it = model->operators.begin() + op_index;
+ auto* op = it->get();
+
+ // If the data type of some input is unknown, we need to yield.
+ for (const auto& input : op->inputs) {
+ if (model->arrays[input]->data_type == ArrayDataType::kNone) {
+ return false;
+ }
+ }
+ // Record data types of output before processing, so we can see at the
+ // end if we changed anything, and return the correct boolean value.
+ std::unordered_map<string, ArrayDataType> old_output_data_types;
+ for (const auto& output : op->outputs) {
+ old_output_data_types[output] = model->arrays[output]->data_type;
+ }
+ // Do the actual output data types propagation.
+ if (op->type == OperatorType::kDequantize ||
+ op->type == OperatorType::kResizeBilinear) {
+ // These operators unconditionally produce float outputs
+ SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat);
+ } else if (op->type == OperatorType::kTensorFlowLess ||
+ op->type == OperatorType::kTensorFlowLessEqual ||
+ op->type == OperatorType::kTensorFlowGreater ||
+ op->type == OperatorType::kTensorFlowGreaterEqual) {
+ // These operators unconditionally produce bool outputs
+ SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool);
+ } else if (op->type == OperatorType::kTensorFlowShape) {
+ // These operators are assumed to produce int32 outputs.
+ SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32);
+ } else if (op->type == OperatorType::kAveragePool ||
+ op->type == OperatorType::kMaxPool ||
+ op->type == OperatorType::kL2Pool ||
+ op->type == OperatorType::kConv ||
+ op->type == OperatorType::kDepthwiseConv ||
+ op->type == OperatorType::kFullyConnected ||
+ op->type == OperatorType::kTensorFlowMax ||
+ op->type == OperatorType::kTensorFlowMin ||
+ op->type == OperatorType::kPad ||
+ op->type == OperatorType::kStridedSlice ||
+ op->type == OperatorType::kTensorFlowReshape ||
+ op->type == OperatorType::kSlice ||
+ op->type == OperatorType::kSqueeze ||
+ op->type == OperatorType::kTensorFlowSum ||
+ op->type == OperatorType::kTensorFlowSwitch ||
+ op->type == OperatorType::kTensorFlowTile ||
+ op->type == OperatorType::kTensorFlowAll ||
+ op->type == OperatorType::kReorderAxes ||
+ op->type == OperatorType::kTensorFlowConcatV2 ||
+ op->type == OperatorType::kFloor ||
+ op->type == OperatorType::kGather ||
+ op->type == OperatorType::kSpaceToBatchND ||
+ op->type == OperatorType::kBatchToSpaceND ||
+ op->type == OperatorType::kMean) {
+ // These operators produce outputs with the same type as their 1st input
+ CHECK_GT(op->inputs.size(), 0);
+ const ArrayDataType data_type = model->arrays[op->inputs[0]]->data_type;
+ SetDataTypeForAllOutputs(model, op, data_type);
+ } else if (op->type == OperatorType::kTensorFlowSplit ||
+ op->type == OperatorType::kTensorFlowConcat) {
+ // These operators produce an output with the same type as their 2nd input
+ CHECK_GT(op->inputs.size(), 1);
+ const ArrayDataType data_type = model->arrays[op->inputs[1]]->data_type;
+ SetDataTypeForAllOutputs(model, op, data_type);
+ } else if (op->type == OperatorType::kCast) {
+ // Data type of the Cast op is specified.
+ CHECK_EQ(op->outputs.size(), 1);
+ auto* cast_op = static_cast<CastOperator*>(op);
+ model->arrays[op->outputs[0]]->data_type = cast_op->dst_data_type;
+ } else if (op->type == OperatorType::kTensorFlowUnsupported) {
+ auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op);
+ if (unsupported_op->output_data_types.size() != op->outputs.size()) {
+ return false;
+ }
+ for (int i = 0; i < unsupported_op->output_data_types.size(); ++i) {
+ auto output = op->outputs[i];
+ auto data_type = unsupported_op->output_data_types[i];
+ model->arrays[output]->data_type = data_type;
+ }
+ } else {
+ // These operators produce an output with the same type as any of their
+ // inputs, which must always have the same type.
+ const ArrayDataType data_type = CommonDataTypeOfAllInputs(*model, *op);
+ SetDataTypeForAllOutputs(model, op, data_type);
+ }
+ // Return true if any output data type changed, false if none changed.
+ for (const auto& output : op->outputs) {
+ if (old_output_data_types[output] != model->arrays[output]->data_type) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
new file mode 100644
index 0000000000..82a43bc2ce
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -0,0 +1,1129 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+void ComputeConvSizes(const Shape& input_shape, int output_depth, int kwidth,
+ int kheight, int stride_width, int stride_height,
+ PaddingType padding_type, Shape* output_shape,
+ FixedPadding* fixed_padding) {
+ const int input_width = input_shape.dims(2);
+ const int input_height = input_shape.dims(1);
+ const int batch = input_shape.dims(0);
+
+ int output_height = 0;
+ int output_width = 0;
+ if (padding_type == PaddingType::kValid) {
+ output_height = (input_height + stride_height - kheight) / stride_height;
+ output_width = (input_width + stride_width - kwidth) / stride_width;
+ } else if (padding_type == PaddingType::kSame) {
+ output_height = (input_height + stride_height - 1) / stride_height;
+ output_width = (input_width + stride_width - 1) / stride_width;
+ } else {
+ LOG(FATAL) << "Only supporting SAME or VALID padding";
+ }
+
+ fixed_padding->height =
+ ((output_height - 1) * stride_height + kheight - input_height) / 2;
+ fixed_padding->width =
+ ((output_width - 1) * stride_width + kwidth - input_width) / 2;
+
+ // Actually had to debug a situation where those were negative due to bad
+ // propagation of placeholder -1 sizes in TensorFlowReshape.
+ CHECK_GT(output_width, 0);
+ CHECK_GT(output_height, 0);
+ output_shape->ReplaceDims({batch, output_height, output_width, output_depth});
+}
+
+void ComputeBinaryOperatorOutputSize(const Shape& input_shape1,
+ const Shape& input_shape2,
+ Array* output_array) {
+ const int size1 = RequiredBufferSizeForShape(input_shape1);
+ const int size2 = RequiredBufferSizeForShape(input_shape2);
+ if (size1 > size2) {
+ output_array->copy_shape(input_shape1);
+ } else if (size2 > size1) {
+ output_array->copy_shape(input_shape2);
+ } else {
+ CHECK_EQ(size1, size2);
+ const int dims1 = input_shape1.dimensions_count();
+ const int dims2 = input_shape2.dimensions_count();
+ if (dims1 >= dims2) {
+ output_array->copy_shape(input_shape1);
+ } else {
+ output_array->copy_shape(input_shape2);
+ }
+ }
+ CHECK(output_array->has_shape());
+}
+
+int GetOutputDepthFromWeights(const Model& model, const Operator& op) {
+ const string& weights_name = op.inputs[1];
+ const auto& weights_shape = model.arrays.at(weights_name)->shape();
+ if (op.type == OperatorType::kConv ||
+ op.type == OperatorType::kFullyConnected) {
+ return weights_shape.dims(0);
+ } else if (op.type == OperatorType::kDepthwiseConv) {
+ return weights_shape.dims(3);
+ } else {
+ LOG(FATAL) << "Unhandled operator type";
+ }
+}
+
+bool EnsureBiasVectorShape(Model* model, Operator* op) {
+ const string& weights_name = op->inputs[1];
+ const auto& weights_array = *model->arrays[weights_name];
+ // Yield until weights shape has been resolved.
+ if (!weights_array.has_shape()) {
+ return false;
+ }
+
+ if (op->inputs.size() < 3) {
+ return false;
+ }
+ auto& bias_array = *model->arrays[op->inputs[2]];
+ if (bias_array.has_shape()) {
+ return true;
+ }
+
+ const int output_depth = GetOutputDepthFromWeights(*model, *op);
+ bias_array.copy_shape(Shape({output_depth}));
+
+ auto& float_buffer = bias_array.GetMutableBuffer<ArrayDataType::kFloat>();
+ float_buffer.data.resize(output_depth, 0);
+
+ return true;
+}
+
+void ProcessConvOperator(Model* model, ConvOperator* op) {
+ if (!EnsureBiasVectorShape(model, op)) {
+ return;
+ }
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+
+ const auto& weights_array = *model->arrays[op->inputs[1]];
+ // Yield until weights dims have been resolved.
+ if (!weights_array.has_shape()) {
+ return;
+ }
+ const auto& weights_shape = weights_array.shape();
+ CHECK_EQ(weights_shape.dimensions_count(), 4);
+
+ auto& output_array = model->GetArray(op->outputs[0]);
+ const int output_depth = weights_shape.dims(0);
+ const int kheight = weights_shape.dims(1);
+ const int kwidth = weights_shape.dims(2);
+ ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
+ op->stride_height, op->padding.type,
+ output_array.mutable_shape(),
+ &op->padding.GetOrCreateFixedPadding());
+ CHECK_EQ(output_array.shape().dimensions_count(), 4);
+
+ // Set im2col array dimensions if there is one.
+ if (op->outputs.size() == 2) {
+ const auto& output_shape = output_array.shape();
+ const int input_depth = weights_shape.dims(3);
+ auto& im2col_array = *model->arrays[op->outputs[1]];
+ im2col_array.copy_shape(Shape{output_shape.dims(0), output_shape.dims(1),
+ output_shape.dims(2),
+ input_depth * kheight * kwidth});
+ }
+}
+
+void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
+ if (!EnsureBiasVectorShape(model, op)) {
+ return;
+ }
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+
+ const auto& weights_array = *model->arrays[op->inputs[1]];
+ // Yield until weights dims have been resolved.
+ if (!weights_array.has_shape()) {
+ return;
+ }
+ const auto& weights_shape = weights_array.shape();
+ CHECK_EQ(weights_shape.dimensions_count(), 4);
+
+ const string& output_name = op->outputs[0];
+ const int input_depth = input_shape.dims(3);
+ const int output_depth = weights_shape.dims(3);
+ // TensorFlow doesn't define the depth_multiplier value on DepthwiseConv ops,
+ // instead it has to be inferred from the weights dims. However, once we are
+ // here, weights dims have already been converted to our own internal format,
+ // where the multiplier is no longer readily apparent. So instead we get it
+ // as the quotient of output and input depths. We only want to do that when
+ // depth_multiplier had the zero value: any other value should be checked
+ // as done by the next if() below.
+ if (!op->depth_multiplier) {
+ op->depth_multiplier = output_depth / input_depth;
+ }
+ QCHECK_EQ(output_depth, input_depth * op->depth_multiplier)
+ << "input/output depths and depth_multiplier don't match";
+
+ const int kheight = weights_shape.dims(1);
+ const int kwidth = weights_shape.dims(2);
+ ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
+ op->stride_height, op->padding.type,
+ model->GetArray(output_name).mutable_shape(),
+ &op->padding.GetOrCreateFixedPadding());
+}
+
+void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) {
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+
+ const string& output_name = op->outputs[0];
+ const int block_size = op->block_size;
+ CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
+ const int batch = input_shape.dims(0);
+ const int height = input_shape.dims(1);
+ const int width = input_shape.dims(2);
+ const int depth = input_shape.dims(3);
+ QCHECK_EQ(depth % (block_size * block_size), 0);
+
+ model->GetArray(output_name)
+ .copy_shape(Shape({batch, height * block_size, width * block_size,
+ depth / block_size / block_size}));
+}
+
+void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) {
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+
+ const string& output_name = op->outputs[0];
+ const int block_size = op->block_size;
+ CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
+ const int batch = input_shape.dims(0);
+ const int height = input_shape.dims(1);
+ const int width = input_shape.dims(2);
+ const int depth = input_shape.dims(3);
+ QCHECK_EQ(width % block_size, 0);
+ QCHECK_EQ(height % block_size, 0);
+
+ model->GetArray(output_name)
+ .copy_shape(Shape({batch, height / block_size, width / block_size,
+ depth * block_size * block_size}));
+}
+
+void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
+ if (!EnsureBiasVectorShape(model, op)) {
+ return;
+ }
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_GE(input_shape.dimensions_count(), 1);
+
+ const auto& weights_array = *model->arrays[op->inputs[1]];
+ // Yield until weights dims have been resolved.
+ if (!weights_array.has_shape()) {
+ return;
+ }
+ const auto& weights_shape = weights_array.shape();
+
+ const int weights_output_depth = weights_shape.dims(0);
+ CHECK_EQ(weights_shape.dimensions_count(), 2);
+
+ const int input_overall_size = RequiredBufferSizeForShape(input_shape);
+ const int matmul_repeats = input_overall_size / weights_shape.dims(1);
+ CHECK_EQ(matmul_repeats * weights_shape.dims(1), input_overall_size);
+
+ auto& output_array = model->GetArray(op->outputs[0]);
+ output_array.copy_shape(Shape({matmul_repeats, weights_output_depth}));
+}
+
+void ProcessTensorFlowReshapeOperator(Model* model,
+ TensorFlowReshapeOperator* op) {
+ auto& output_array = *model->arrays[op->outputs[0]];
+ // Bail if we already have output dims
+ if (output_array.has_shape()) {
+ return;
+ }
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+
+ const string& shape_name = op->inputs[1];
+ auto& shape_array = model->GetArray(shape_name);
+ // Yield until the shape is resolved as a constant array
+ if (!shape_array.buffer) {
+ return;
+ }
+ CHECK(shape_array.data_type == ArrayDataType::kInt32);
+ // shape_data is the raw array of ints describing the shape
+ // in the TensorFlow node. We intentionally make a copy here, rather than
+ // modify wildcards in-place below, because in some graphs, the same shape
+ // array with a wildcard may be referenced from multiple Reshape nodes, where
+ // the wildcard needs to resolved to distinct values.
+ std::vector<int32> shape_data =
+ shape_array.GetBuffer<ArrayDataType::kInt32>().data;
+ // The Reshape shape may have a wildcard dim, encoded as -1.
+ bool has_wildcard = false;
+ int wildcard_index = 0;
+ int product_non_wildcard_dims = 1;
+ for (int i = 0; i < shape_data.size(); i++) {
+ if (shape_data[i] == -1) {
+ CHECK(!has_wildcard);
+ has_wildcard = true;
+ wildcard_index = i;
+ } else {
+ product_non_wildcard_dims *= shape_data[i];
+ }
+ }
+ const int input_flat_size = RequiredBufferSizeForShape(input_shape);
+ if (has_wildcard) {
+ shape_data[wildcard_index] = input_flat_size / product_non_wildcard_dims;
+ }
+ auto& output_shape = *output_array.mutable_shape();
+ *output_shape.mutable_dims() = shape_data;
+ const int output_flat_size = RequiredBufferSizeForShape(output_shape);
+ CHECK_EQ(output_flat_size, input_flat_size);
+}
+
+void ProcessSimpleOperator(Model* model, Operator* op) {
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+
+ const string& output_name = op->outputs[0];
+ auto& output_array = *model->arrays[output_name];
+ if (output_array.has_shape()) {
+ return;
+ }
+
+ output_array.copy_shape(input_array.shape());
+}
+
+void ProcessSimpleBinaryOperator(Model* model, Operator* op) {
+ CHECK_EQ(op->inputs.size(), 2);
+ const auto& input0_array = *model->arrays[op->inputs[0]];
+ const auto& input1_array = *model->arrays[op->inputs[1]];
+ // Yield until input dims have been resolved.
+ if (!input0_array.has_shape() || !input1_array.has_shape()) {
+ return;
+ }
+ const string& output_name = op->outputs[0];
+ auto& output_array = *model->arrays[output_name];
+ ComputeBinaryOperatorOutputSize(input0_array.shape(), input1_array.shape(),
+ &output_array);
+}
+
+void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
+ CHECK_LE(op->inputs.size(), 2);
+ auto& output_array = *model->arrays[op->outputs[0]];
+ if (output_array.has_shape()) {
+ return;
+ }
+ if (op->inputs.size() == 2) {
+ // There is a reduction_indices input.
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& reduction_array = *model->arrays[op->inputs[1]];
+ if (!reduction_array.buffer) {
+ return;
+ }
+ if (!input_array.has_shape()) {
+ return;
+ }
+ auto& input_shape = input_array.shape();
+ CHECK(reduction_array.buffer->type == ArrayDataType::kInt32);
+ const auto& reduction_array_vals =
+ reduction_array.GetBuffer<ArrayDataType::kInt32>().data;
+ auto& output_dims = *output_array.mutable_shape()->mutable_dims();
+ output_dims.clear();
+ for (int i = 0; i < input_shape.dimensions_count(); i++) {
+ bool is_reduction_dim = false;
+ for (int r : reduction_array_vals) {
+ if (i == r) {
+ is_reduction_dim = true;
+ }
+ }
+ if (!is_reduction_dim) {
+ output_dims.push_back(input_shape.dims(i));
+ }
+ }
+ } else {
+ // No reduction_indices means complete reduction to a single scalar.
+ output_array.copy_shape(Shape({}));
+ }
+}
+
+void ProcessSliceOperator(Model* model, SliceOperator* op) {
+ CHECK_EQ(op->inputs.size(), 3);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ // Yield until the Slice params have been resolved.
+ if (op->begin.empty()) return;
+
+ // Yield until input dims have been resolved.
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ if (!input_array.has_shape()) return;
+ const Shape& input_shape = input_array.shape();
+
+ auto& output_array = *model->arrays[op->outputs[0]];
+ if (output_array.has_shape()) return;
+
+ CHECK_EQ(input_shape.dims().size(), op->size.size());
+ CHECK_EQ(op->begin.size(), op->size.size());
+
+ std::vector<int> output_dims;
+ for (int i = 0; i < op->begin.size(); ++i) {
+ int size = op->size[i];
+ if (size == -1) {
+ size = input_array.shape().dims(i) - op->begin[i];
+ }
+ output_dims.push_back(size);
+ }
+
+ *output_array.mutable_shape()->mutable_dims() = output_dims;
+}
+
+void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) {
+ const string& input_name = op->inputs[0];
+ const auto& input_array = *model->arrays[input_name];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ const string& output_name = op->outputs[0];
+ Shape* output_shape = model->GetArray(output_name).mutable_shape();
+ ShuffleDims(input_shape, op->input_axes_order, op->output_axes_order,
+ output_shape);
+}
+
+void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
+ // Yield until input dims have been resolved.
+ for (const auto& input_name : op->inputs) {
+ auto& input_array = *model->arrays[input_name];
+ if (!input_array.has_shape()) {
+ return;
+ }
+ }
+ auto& output_array = model->GetArray(op->outputs[0]);
+ // Use 0 input as basis for output dimensions.
+ const auto& first_input_array = *model->arrays[op->inputs[0]];
+ output_array.copy_shape(first_input_array.shape());
+ // Determine the concat size, and enfore that all inputs have
+ // the same dimensions count.
+ int concat_size = 0;
+ for (const auto& input_name : op->inputs) {
+ auto& input_array = *model->arrays[input_name];
+ CHECK(input_array.has_shape());
+ if (input_array.shape().dimensions_count() == 0) {
+ continue;
+ }
+ CHECK_EQ(input_array.shape().dimensions_count(),
+ output_array.shape().dimensions_count());
+ const std::vector<int>& input_dims = input_array.shape().dims();
+ CHECK_LT(op->concat_dim, input_dims.size());
+ concat_size += input_dims[op->concat_dim];
+ }
+ // Write out the concat_size on the output array shape.
+ auto& output_shape = *output_array.mutable_shape();
+ auto& output_dims = *output_shape.mutable_dims();
+ CHECK_LT(op->concat_dim, output_shape.dimensions_count());
+ output_dims[op->concat_dim] = concat_size;
+}
+
+void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) {
+ CHECK_EQ(op->inputs.size(), 2);
+ const string& input_name = op->inputs[1];
+ const auto& input_array = *model->arrays[input_name];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const Shape& input_shape = input_array.shape();
+
+ // This code is slightly suspect. The TensorFlow docs say that the axis
+ // selection defaults to 0, but we are splitting across the final axis.
+ const int input_dims_count = input_shape.dimensions_count();
+ const int input_depth = input_shape.dims(input_dims_count - 1);
+ CHECK_EQ(input_depth % op->num_split, 0);
+ const int split_depth = input_depth / op->num_split;
+
+ Shape output_shape = input_shape;
+ (*output_shape.mutable_dims())[input_dims_count - 1] = split_depth;
+
+ CHECK_EQ(op->outputs.size(), op->num_split);
+ for (const auto& output : op->outputs) {
+ model->arrays[output]->copy_shape(output_shape);
+ }
+}
+
+void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) {
+ const string& input_name = op->inputs[0];
+ const auto& input_array = *model->arrays[input_name];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+ const string& output_name = op->outputs[0];
+ const int output_depth = input_shape.dims(3);
+ ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
+ op->stride_width, op->stride_height, op->padding.type,
+ model->GetArray(output_name).mutable_shape(),
+ &op->padding.GetOrCreateFixedPadding());
+}
+
+void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) {
+ const string& input_name = op->inputs[0];
+ const auto& input_array = *model->arrays[input_name];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+ const string& output_name = op->outputs[0];
+ const int output_depth = input_shape.dims(3);
+ ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
+ op->stride_width, op->stride_height, op->padding.type,
+ model->GetArray(output_name).mutable_shape(),
+ &op->padding.GetOrCreateFixedPadding());
+}
+
+void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) {
+ const string& input_name = op->inputs[0];
+ const auto& input_array = *model->arrays[input_name];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ if (input_shape.dimensions_count() < 4) {
+ LOG(FATAL) << "missing dimensions for " << input_name;
+ }
+ const string& output_name = op->outputs[0];
+ const int output_depth = input_shape.dims(3);
+ ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
+ op->stride_width, op->stride_height, op->padding.type,
+ model->GetArray(output_name).mutable_shape(),
+ &op->padding.GetOrCreateFixedPadding());
+}
+
+void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) {
+ CHECK_EQ(op->inputs.size(), 2);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ if (!model->arrays[op->inputs[0]]->has_shape() ||
+ !model->arrays[op->inputs[1]]->has_shape()) {
+ return;
+ }
+ const auto& input_data_shape = model->arrays[op->inputs[0]]->shape();
+
+ const string& output_size_name = op->inputs[1];
+ const auto& output_size_array = *model->arrays[output_size_name];
+ CHECK(output_size_array.data_type == ArrayDataType::kInt32);
+ CHECK(output_size_array.has_shape());
+ const auto& output_size_shape = output_size_array.shape();
+ CHECK_EQ(output_size_shape.dimensions_count(), 1);
+ CHECK_EQ(output_size_shape.dims(0), 2);
+ std::vector<int32> output_shape =
+ output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
+ model->arrays[op->outputs[0]]->copy_shape(
+ Shape({input_data_shape.dims(0), output_shape[0], output_shape[1],
+ input_data_shape.dims(3)}));
+}
+
+void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
+ // I/O arrays should be allocated on creation of op.
+ QCHECK_EQ(op->inputs.size(), LstmCellOperator::NUM_INPUTS);
+ QCHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS);
+
+ const auto& input_array =
+ *model->arrays[op->inputs[LstmCellOperator::DATA_INPUT]];
+ // Yield until all input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_GE(input_shape.dimensions_count(), 2);
+
+ const auto& prev_activ_array =
+ *model->arrays[op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]];
+ // Yield until all input dims have been resolved.
+ if (!prev_activ_array.has_shape()) {
+ return;
+ }
+ const auto& prev_activ_shape = prev_activ_array.shape();
+ CHECK_GE(prev_activ_shape.dimensions_count(), 2);
+
+ const auto& weights_array =
+ *model->arrays[op->inputs[LstmCellOperator::WEIGHTS_INPUT]];
+ // Yield until weights dims have been resolved.
+ if (!weights_array.has_shape()) {
+ return;
+ }
+ const auto& weights_shape = weights_array.shape();
+ CHECK_EQ(weights_shape.dimensions_count(), 2);
+
+ const auto& bias_array =
+ *model->arrays[op->inputs[LstmCellOperator::BIASES_INPUT]];
+ // Yield until bias dims have been resolved.
+ if (!bias_array.has_shape()) {
+ return;
+ }
+ const auto& bias_shape = bias_array.shape();
+ CHECK_GE(bias_shape.dimensions_count(), 1);
+
+ const auto& prev_state_array =
+ *model->arrays[op->inputs[LstmCellOperator::PREV_STATE_INPUT]];
+ // Yield until all input dims have been resolved.
+ if (!prev_state_array.has_shape()) {
+ return;
+ }
+ const auto& prev_state_shape = prev_state_array.shape();
+ CHECK_GE(prev_state_shape.dimensions_count(), 2);
+
+ const int fc_output_depth = weights_shape.dims(0);
+ CHECK_EQ(fc_output_depth, bias_shape.dims(0));
+ CHECK_EQ(fc_output_depth % 4, 0);
+ const int depth = fc_output_depth / 4;
+
+ const int input_depth = input_shape.dims(input_shape.dimensions_count() - 1);
+ const int fc_input_depth = weights_shape.dims(1);
+ CHECK_EQ(input_depth + depth, fc_input_depth);
+ Shape output_shape(input_shape);
+ (*output_shape.mutable_dims())[output_shape.dimensions_count() - 1] = depth;
+
+ // Set output dimensions
+ model->GetArray(op->outputs[LstmCellOperator::STATE_OUTPUT])
+ .copy_shape(output_shape);
+ model->GetArray(op->outputs[LstmCellOperator::ACTIV_OUTPUT])
+ .copy_shape(output_shape);
+
+ Shape concat_temp_shape(input_shape);
+ (*concat_temp_shape
+ .mutable_dims())[concat_temp_shape.dimensions_count() - 1] =
+ fc_input_depth;
+ model->GetArray(op->outputs[LstmCellOperator::CONCAT_TEMP])
+ .copy_shape(concat_temp_shape);
+
+ Shape activ_temp_shape(input_shape);
+ (*activ_temp_shape.mutable_dims())[activ_temp_shape.dimensions_count() - 1] =
+ fc_output_depth;
+ model->GetArray(op->outputs[LstmCellOperator::ACTIV_TEMP])
+ .copy_shape(activ_temp_shape);
+}
+
+void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+ const auto input_height = input_shape.dims(1);
+ const auto input_width = input_shape.dims(2);
+
+ const auto& block_shape_array = *model->arrays[op->inputs[1]];
+ const auto& paddings_array = *model->arrays[op->inputs[2]];
+ const auto& block_shape_array_shape = block_shape_array.shape();
+ const auto& paddings_array_shape = paddings_array.shape();
+ QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
+ QCHECK_EQ(paddings_array_shape.dimensions_count(), 2);
+
+ // We only support two dimensions.
+ QCHECK_EQ(block_shape_array_shape.dims(0), 2);
+ if (!block_shape_array.buffer) {
+ return;
+ }
+ QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
+ const auto& block_shape_data =
+ block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
+ auto block_height = block_shape_data[0];
+ auto block_width = block_shape_data[1];
+
+ QCHECK_EQ(paddings_array_shape.dims(0), 2); // Number of block dimensions
+ QCHECK_EQ(paddings_array_shape.dims(1), 2); // Two parameters per dimension.
+ if (!paddings_array.buffer) {
+ return;
+ }
+ QCHECK(paddings_array.data_type == ArrayDataType::kInt32);
+ const auto& paddings_data =
+ paddings_array.GetBuffer<ArrayDataType::kInt32>().data;
+ int height_with_paddings = input_height + paddings_data[0] + paddings_data[1];
+ int width_with_paddings = input_width + paddings_data[2] + paddings_data[3];
+ QCHECK_EQ(height_with_paddings % block_height, 0);
+ QCHECK_EQ(width_with_paddings % block_width, 0);
+ int output_height = height_with_paddings / block_height;
+ int output_width = width_with_paddings / block_width;
+
+ model->arrays[op->outputs[0]]->copy_shape(
+ Shape({input_shape.dims(0) * block_height * block_width, output_height,
+ output_width, input_shape.dims(3)}));
+}
+
+void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) {
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+ const auto input_height = input_shape.dims(1);
+ const auto input_width = input_shape.dims(2);
+
+ const auto& block_shape_array = *model->arrays[op->inputs[1]];
+ const auto& crops_array = *model->arrays[op->inputs[2]];
+ const auto& block_shape_array_shape = block_shape_array.shape();
+ const auto& crops_array_shape = crops_array.shape();
+ QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
+ QCHECK_EQ(crops_array_shape.dimensions_count(), 2);
+
+ // We only support two dimensions.
+ QCHECK_EQ(block_shape_array_shape.dims(0), 2);
+ if (!block_shape_array.buffer) {
+ return;
+ }
+ QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
+ const auto& block_shape_data =
+ block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
+ auto block_height = block_shape_data[0];
+ auto block_width = block_shape_data[1];
+
+ QCHECK_EQ(crops_array_shape.dims(0), 2); // Number of block dimensions
+ QCHECK_EQ(crops_array_shape.dims(1), 2); // Two parameters per dimension.
+ if (!crops_array.buffer) {
+ return;
+ }
+ QCHECK(crops_array.data_type == ArrayDataType::kInt32);
+ const auto& crops_data = crops_array.GetBuffer<ArrayDataType::kInt32>().data;
+ // We don't support crops now.
+ QCHECK_EQ(crops_data[0], 0);
+ QCHECK_EQ(crops_data[1], 0);
+ QCHECK_EQ(crops_data[2], 0);
+ QCHECK_EQ(crops_data[3], 0);
+
+ QCHECK_EQ(input_shape.dims(0) % (block_height * block_width), 0);
+
+ int output_height = input_height * block_height;
+ int output_width = input_width * block_width;
+
+ model->arrays[op->outputs[0]]->copy_shape(
+ Shape({input_shape.dims(0) / (block_height * block_width), output_height,
+ output_width, input_shape.dims(3)}));
+}
+
+void ProcessGatherOperator(Model* model, GatherOperator* op) {
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& indices_array = *model->arrays[op->inputs[1]];
+ auto& output_array = *model->arrays[op->outputs[0]];
+
+ // Bail if we already know the output shape.
+ if (output_array.has_shape()) {
+ return;
+ }
+
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape() || !indices_array.has_shape()) {
+ return;
+ }
+
+ const auto& input_shape = input_array.shape();
+ const auto& indices_shape = indices_array.shape();
+ QCHECK_GE(input_shape.dimensions_count(), 1);
+ op->input_rank = input_shape.dimensions_count();
+
+ // We only support 1-D indices.
+ QCHECK_EQ(indices_shape.dimensions_count(), 1);
+
+ // Copy the input dimensions to the output except for dimension 0,
+ // where the dimension of indices_shape is used.
+ auto output_dims = output_array.mutable_shape()->mutable_dims();
+ output_dims->push_back(indices_shape.dims(0));
+ for (int dim = 1; dim < input_shape.dimensions_count(); dim++) {
+ output_dims->push_back(input_shape.dims(dim));
+ }
+}
+
+void ProcessPadOperator(Model* model, PadOperator* op) {
+ CHECK_EQ(op->inputs.size(), 2);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) return;
+
+ if (op->left_padding.empty()) return;
+ CHECK_EQ(op->left_padding.size(), op->right_padding.size());
+
+ auto& output_array = *model->arrays[op->outputs[0]];
+ if (output_array.has_shape()) return;
+
+ Shape output_shape = input_array.shape();
+ std::vector<int>& dims = *output_shape.mutable_dims();
+ CHECK_EQ(op->left_padding.size(), dims.size());
+
+ for (int i = 0; i < op->left_padding.size(); ++i) {
+ dims[i] += op->left_padding[i] + op->right_padding[i];
+ }
+
+ output_array.copy_shape(output_shape);
+}
+
+void ProcessMeanOperator(Model* model, MeanOperator* op) {
+ CHECK_EQ(op->inputs.size(), 2);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) return;
+ const std::vector<int>& indices = op->reduction_indices;
+ if (indices.empty()) return;
+
+ auto& output_array = *model->arrays[op->outputs[0]];
+ if (output_array.has_shape()) return;
+
+ const std::vector<int>& input_dims = input_array.shape().dims();
+ std::vector<int> output_dims;
+ for (int i = 0; i < input_dims.size(); ++i) {
+ if (std::find(indices.begin(), indices.end(), i) == indices.end()) {
+ output_dims.push_back(input_dims[i]);
+ }
+ }
+ CHECK(!output_dims.empty());
+ CHECK_EQ(output_dims.size(), 2);
+
+ *output_array.mutable_shape()->mutable_dims() = output_dims;
+}
+
+void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
+ CHECK_EQ(op->inputs.size(), 4);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) return;
+
+ if (op->start_indices.empty()) return;
+ CHECK_EQ(op->start_indices.size(), op->stop_indices.size());
+ CHECK_EQ(op->start_indices.size(), op->strides.size());
+
+ auto& output_array = *model->arrays[op->outputs[0]];
+ if (output_array.has_shape()) return;
+
+ Shape output_shape = input_array.shape();
+ std::vector<int>& dims = *output_shape.mutable_dims();
+ CHECK_EQ(op->start_indices.size(), dims.size());
+
+ for (int i = 0; i < op->start_indices.size(); ++i) {
+ const int mask = 1 << i;
+ const int start = (op->begin_mask & mask) ? 0 : op->start_indices[i];
+ const int stop = (op->end_mask & mask) ? input_array.shape().dims()[i]
+ : op->stop_indices[i];
+ dims[i] = (stop - start) / op->strides[i];
+ }
+
+ output_array.copy_shape(output_shape);
+}
+
+void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) {
+ CHECK_EQ(op->inputs.size(), 1);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) return;
+
+ auto& output_array = *model->arrays[op->outputs[0]];
+ if (output_array.has_shape()) return;
+
+ const std::vector<int>& input_dims = input_array.shape().dims();
+ std::vector<int> output_dims;
+
+ for (int i = 0; i < input_dims.size(); ++i) {
+ if (input_dims[i] != 1 ||
+ (!op->squeeze_dims.empty() &&
+ std::find(op->squeeze_dims.begin(), op->squeeze_dims.end(), i) ==
+ op->squeeze_dims.end())) {
+ output_dims.push_back(input_dims[i]);
+ }
+ }
+ *output_array.mutable_shape()->mutable_dims() = output_dims;
+}
+
+void ProcessSvdfOperator(Model* model, SvdfOperator* op) {
+ CHECK(op->inputs.size() == 3 || op->inputs.size() == 4);
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ if (!input_array.has_shape()) return;
+
+ auto& weights_feature_array = *model->arrays[op->inputs[1]];
+ if (!weights_feature_array.has_shape()) return;
+
+ const auto& weights_time_array = *model->arrays[op->inputs[2]];
+ if (!weights_time_array.has_shape()) return;
+
+ const bool has_bias = (op->inputs.size() == 4);
+ if (has_bias) {
+ const auto& bias_array = *model->arrays[op->inputs[3]];
+ if (!bias_array.has_shape()) return;
+ }
+
+ const int batch_size = input_array.shape().dims()[0];
+ const int num_units = weights_feature_array.shape().dims()[0];
+ const int memory_size = weights_time_array.shape().dims()[1];
+
+ auto& state_array = model->GetArray(op->outputs[0]);
+ state_array.mutable_shape()->ReplaceDims(
+ {batch_size, memory_size * num_units});
+
+ auto& output_array = model->GetArray(op->outputs[1]);
+ output_array.mutable_shape()->ReplaceDims({batch_size, num_units});
+}
+} // namespace
+
+bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
+ auto it = model->operators.begin() + op_index;
+ auto* op = it->get();
+ std::unordered_map<string, std::vector<int>> old_output_dims;
+ for (const auto& output : op->outputs) {
+ if (model->arrays[output]->has_shape()) {
+ old_output_dims[output] = model->arrays[output]->shape().dims();
+ }
+ }
+
+ switch (op->type) {
+ case OperatorType::kBatchNormalization:
+ case OperatorType::kL2Normalization:
+ case OperatorType::kDequantize:
+ case OperatorType::kRelu:
+ case OperatorType::kRelu1:
+ case OperatorType::kRelu6:
+ case OperatorType::kSoftmax:
+ case OperatorType::kLogistic:
+ case OperatorType::kTanh:
+ case OperatorType::kLocalResponseNormalization:
+ case OperatorType::kTensorFlowIdentity:
+ case OperatorType::kFakeQuant:
+ case OperatorType::kTensorFlowRsqrt:
+ case OperatorType::kTensorFlowSqrt:
+ case OperatorType::kTensorFlowSquare:
+ case OperatorType::kTensorFlowAll:
+ case OperatorType::kTensorFlowAssert:
+ case OperatorType::kCast:
+ case OperatorType::kFloor:
+ ProcessSimpleOperator(model, op);
+ break;
+ case OperatorType::kGather:
+ ProcessGatherOperator(model, static_cast<GatherOperator*>(op));
+ break;
+
+ case OperatorType::kAdd:
+ case OperatorType::kSub:
+ case OperatorType::kMul:
+ case OperatorType::kDiv:
+ case OperatorType::kTensorFlowLess:
+ case OperatorType::kTensorFlowLessEqual:
+ case OperatorType::kTensorFlowGreater:
+ case OperatorType::kTensorFlowMaximum:
+ case OperatorType::kTensorFlowMinimum:
+ case OperatorType::kTensorFlowGreaterEqual:
+ ProcessSimpleBinaryOperator(model, op);
+ break;
+ case OperatorType::kConv:
+ ProcessConvOperator(model, static_cast<ConvOperator*>(op));
+ break;
+ case OperatorType::kDepthwiseConv:
+ ProcessDepthwiseConvOperator(model,
+ static_cast<DepthwiseConvOperator*>(op));
+ break;
+ case OperatorType::kDepthToSpace:
+ ProcessDepthToSpaceOperator(model,
+ static_cast<DepthToSpaceOperator*>(op));
+ break;
+ case OperatorType::kSpaceToDepth:
+ ProcessSpaceToDepthOperator(model,
+ static_cast<SpaceToDepthOperator*>(op));
+ break;
+ case OperatorType::kFullyConnected:
+ ProcessFullyConnectedOperator(model,
+ static_cast<FullyConnectedOperator*>(op));
+ break;
+ case OperatorType::kTensorFlowReshape:
+ ProcessTensorFlowReshapeOperator(
+ model, static_cast<TensorFlowReshapeOperator*>(op));
+ break;
+ case OperatorType::kAveragePool:
+ ProcessAveragePoolOperator(model, static_cast<AveragePoolOperator*>(op));
+ break;
+ case OperatorType::kMaxPool:
+ ProcessMaxPoolOperator(model, static_cast<MaxPoolOperator*>(op));
+ break;
+ case OperatorType::kL2Pool:
+ ProcessL2PoolOperator(model, static_cast<L2PoolOperator*>(op));
+ break;
+ case OperatorType::kTensorFlowMin:
+ case OperatorType::kTensorFlowMax:
+ case OperatorType::kTensorFlowSum:
+ ProcessTensorFlowReductionOperator(model, op);
+ break;
+
+ case OperatorType::kSlice:
+ ProcessSliceOperator(model, static_cast<SliceOperator*>(op));
+ break;
+
+ case OperatorType::kTensorFlowTile:
+ // We don't currently implement the propagation of fixed sizes through
+ // a TensorFlow Tile.
+ //
+ // Fortunately, we don't need to: so far, we have only dealt with Tile
+ // or Slice ops in subgraphs that are identified as L2Normalization.
+ // See IdentifyL2Normalization.
+ break;
+ case OperatorType::kTensorFlowSwitch:
+ // We can't know the sizes of the outputs until we have resolved the
+ // predicate, and once we have resolved the predicate, the whole
+ // Switch node will get resolved away.
+ // See ResolveTensorFlowSwitch.
+ break;
+ case OperatorType::kTensorFlowMerge:
+ // No need to bother resolving TensorFlow Merge ops: other graph
+ // transformations will remove them anyway.
+ // See ResolveTensorFlowMerge.
+ break;
+ case OperatorType::kTensorFlowSplit:
+ ProcessTensorFlowSplitOperator(model,
+ static_cast<TensorFlowSplitOperator*>(op));
+ break;
+ case OperatorType::kSqueeze:
+ ProcessSqueezeOperator(model, static_cast<SqueezeOperator*>(op));
+ break;
+ case OperatorType::kTensorFlowConcat:
+ case OperatorType::kTensorFlowConcatV2:
+ // Unimplemented, hopefully another graph transformation will
+ // drop it or rewrite it. Concretely, either ResolveTensorFlowConcat
+ // will resolve this node to a DepthConcatenation, or else we have
+ // a more general non-depth concatenation that will hopefully be dropped,
+ // or else at the moment we will abort.
+ break;
+ case OperatorType::kTensorFlowShape:
+ // Unimplemented, hopefully another graph transformation will drop it or
+ // rewrite it.
+ break;
+ case OperatorType::kReorderAxes:
+ ProcessReorderAxesOperator(model, static_cast<ReorderAxesOperator*>(op));
+ break;
+ case OperatorType::kConcatenation:
+ ProcessConcatenationOperator(model,
+ static_cast<ConcatenationOperator*>(op));
+ break;
+ case OperatorType::kResizeBilinear:
+ ProcessResizeBilinearOperator(model,
+ static_cast<ResizeBilinearOperator*>(op));
+ break;
+ case OperatorType::kLstmCell:
+ ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op));
+ break;
+ case OperatorType::kTensorFlowMatMul:
+ // MatMul operators are converted to FullyConnected, after which their
+ // shapes are propagated.
+ break;
+ case OperatorType::kSpaceToBatchND:
+ ProcessSpaceToBatchNDOperator(model,
+ static_cast<SpaceToBatchNDOperator*>(op));
+ break;
+ case OperatorType::kBatchToSpaceND:
+ ProcessBatchToSpaceNDOperator(model,
+ static_cast<BatchToSpaceNDOperator*>(op));
+ break;
+ case OperatorType::kPad:
+ ProcessPadOperator(model, static_cast<PadOperator*>(op));
+ break;
+ case OperatorType::kMean:
+ ProcessMeanOperator(model, static_cast<MeanOperator*>(op));
+ break;
+ case OperatorType::kStridedSlice:
+ ProcessStridedSliceOperator(model,
+ static_cast<StridedSliceOperator*>(op));
+ break;
+ case OperatorType::kTensorFlowUnsupported:
+ break;
+ case OperatorType::kSvdf:
+ ProcessSvdfOperator(model, static_cast<SvdfOperator*>(op));
+ break;
+ default:
+ // Unimplemented, another graph transformation should drop it.
+ LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
+ }
+
+ // Return true if any output dim changed, false if none changed.
+ // Assumption: no transformation clears an output shape, they only add shapes.
+ for (const auto& output : op->outputs) {
+ if (model->arrays[output]->has_shape() &&
+ (old_output_dims[output] != model->arrays[output]->shape().dims())) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
new file mode 100644
index 0000000000..5551755ea7
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -0,0 +1,467 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool SupportsQuantization(const Operator& op) {
+ auto type = op.type;
+ if (type == OperatorType::kTensorFlowUnsupported) {
+ auto* unsupported = static_cast<const TensorFlowUnsupportedOperator*>(&op);
+ return unsupported->quantized;
+ }
+ return type == OperatorType::kConv || type == OperatorType::kDepthwiseConv ||
+ type == OperatorType::kFullyConnected ||
+ type == OperatorType::kConcatenation ||
+ type == OperatorType::kL2Normalization || type == OperatorType::kAdd ||
+ type == OperatorType::kAveragePool || type == OperatorType::kMaxPool ||
+ type == OperatorType::kLogistic || type == OperatorType::kSoftmax ||
+ type == OperatorType::kTensorFlowReshape ||
+ type == OperatorType::kMul || type == OperatorType::kSpaceToDepth ||
+ type == OperatorType::kDepthToSpace;
+}
+
+template <ArrayDataType A>
+std::unique_ptr<GenericBuffer> QuantizeBuffer(
+ const GenericBuffer& buffer,
+ const QuantizationParams& quantization_params) {
+ const auto inverse_scale = 1. / quantization_params.scale;
+ CHECK(buffer.type == ArrayDataType::kFloat);
+ const auto& float_buffer =
+ static_cast<const Buffer<ArrayDataType::kFloat>&>(buffer);
+ auto* quantized_buffer = new Buffer<A>;
+ quantized_buffer->data.resize(float_buffer.data.size());
+ const auto qmin = static_cast<int32>(std::numeric_limits<DataType<A>>::min());
+ const auto qmax = static_cast<int32>(std::numeric_limits<DataType<A>>::max());
+ for (std::size_t i = 0; i < float_buffer.data.size(); i++) {
+ const float src_val = float_buffer.data[i];
+ double scaled_val; // Astonishingly, using 'float' degrades accuracy just
+ // enough to make a few tests fail!
+ if (quantization_params.scale == 0) {
+ CHECK_EQ(src_val, 0) << "The quantization scale for this array is 0, "
+ << "so all its values should be 0.";
+ scaled_val = quantization_params.zero_point;
+ } else {
+ scaled_val = quantization_params.zero_point + inverse_scale * src_val;
+ }
+ const auto rounded_val = static_cast<int32>(std::round(scaled_val));
+ const auto clamped_val = std::min(qmax, std::max(qmin, rounded_val));
+ quantized_buffer->data[i] = static_cast<DataType<A>>(clamped_val);
+ }
+ return std::unique_ptr<GenericBuffer>(quantized_buffer);
+}
+
+template <ArrayDataType A>
+void QuantizeArray(GraphTransformation* transformation, Model* model,
+ const string& name,
+ const QuantizationParams& quantization_params) {
+ auto& array = model->GetArray(name);
+ CHECK(array.data_type == ArrayDataType::kFloat);
+ CHECK(!array.quantization_params);
+ array.GetOrCreateQuantizationParams() = quantization_params;
+ if (array.buffer) {
+ array.buffer = QuantizeBuffer<A>(*array.buffer, quantization_params);
+ }
+ array.data_type = A;
+ transformation->AddMessageF("Quantized array %s", name);
+}
+
+void QuantizeArray(GraphTransformation* transformation, Model* model,
+ const string& name, ArrayDataType quantized_data_type,
+ const QuantizationParams& quantization_params) {
+ switch (quantized_data_type) {
+ case ArrayDataType::kUint8:
+ return QuantizeArray<ArrayDataType::kUint8>(transformation, model, name,
+ quantization_params);
+ case ArrayDataType::kInt32:
+ return QuantizeArray<ArrayDataType::kInt32>(transformation, model, name,
+ quantization_params);
+ default:
+ LOG(FATAL) << "Unhandled case.";
+ }
+}
+
+const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) {
+ auto& array = model->GetArray(array_name);
+ // Normally we should have a MinMax recorded on this Array,
+ // so we just use it.
+ if (array.minmax != nullptr) {
+ return *array.minmax;
+ }
+
+ // We don't have a MinMax. That's bad news: we need
+ // the graph to provide MinMax info for all arrays in order
+ // for inference to reproduce faithfully the same quantization
+ // error as the training process had.
+ //
+ // But we still want to support a fallback for constant arrays,
+ // just using the plain min and max computed from array elements.
+ // We should hopefully never rely on that in production, as that
+ // will not give very good accuracy as that typically won't be
+ // exactly what the training process used. But it will be useful
+ // to allow easily trying out quantization even if the graph
+ // lacks some minmax information.
+ if (array.buffer != nullptr) {
+ LOG(WARNING)
+ << "Constant array " << array_name
+ << " lacks MinMax information. To make up for that, we will now compute"
+ << " the MinMax from actual array elements. That will result in"
+ << " quantization parameters that probably do not match whichever "
+ "arithmetic"
+ << " was used during training, and thus will probably be a cause of "
+ "poor"
+ << " inference accuracy.";
+ CHECK(array.buffer->type == ArrayDataType::kFloat);
+ const auto& data = array.GetBuffer<ArrayDataType::kFloat>().data;
+ // We always want [min, max] to contain 0.
+ float min = 0.f;
+ float max = 0.f;
+ for (auto val : data) {
+ min = std::min(min, val);
+ max = std::max(max, val);
+ }
+ auto& minmax = array.GetOrCreateMinMax();
+ minmax.min = min;
+ minmax.max = max;
+ return minmax;
+ }
+
+ LOG(FATAL) << "Array " << array_name
+ << " does not have MinMax information, "
+ "and is not a constant array. Cannot "
+ "proceed with quantization.";
+}
+
+bool ChooseQuantizationForOperatorInput(
+ GraphTransformation* transformation, Model* model, const Operator& op,
+ std::size_t input_index, ArrayDataType* quantized_data_type,
+ QuantizationParams* quantization_params) {
+ const auto& input = op.inputs[input_index];
+ auto& array = model->GetArray(input);
+ if (array.data_type != ArrayDataType::kFloat) {
+ return false;
+ }
+ if (op.type == OperatorType::kConv ||
+ op.type == OperatorType::kDepthwiseConv ||
+ op.type == OperatorType::kFullyConnected) {
+ if (input_index == 2) {
+ // Quantization of bias vector.
+ // We need both of the mandatory inputs (input activations and weights) to
+ // have
+ // been already quantized.
+ const auto& input_activations = model->GetArray(op.inputs[0]);
+ const auto& input_weights = model->GetArray(op.inputs[1]);
+ if (!input_activations.quantization_params ||
+ !input_weights.quantization_params) {
+ return false;
+ }
+ const auto input_activations_scale =
+ input_activations.quantization_params->scale;
+ const auto input_weights_scale = input_weights.quantization_params->scale;
+ quantization_params->scale =
+ input_activations_scale * input_weights_scale;
+ quantization_params->zero_point = 0;
+ *quantized_data_type = ArrayDataType::kInt32;
+ transformation->AddMessageF(
+ "Input array %s is a bias vector. Choosing quantization params "
+ "accordingly.",
+ input);
+ return true;
+ }
+ }
+
+ const MinMax& minmax = GetOrComputeMinMax(model, input);
+ GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(model->flags, minmax,
+ quantization_params);
+ transformation->AddMessageF(
+ "For input array %s with min=%g"
+ ", max=%g"
+ ", chose to quantize as uint8 with zero_point=%d"
+ ", scale=%g",
+ input, minmax.min, minmax.max, quantization_params->zero_point,
+ quantization_params->scale);
+ *quantized_data_type = ArrayDataType::kUint8;
+ return true;
+}
+
+bool IsExactlyRepresentable(double real_value, ArrayDataType data_type,
+ const QuantizationParams& quantization_params) {
+ const double scaled_value =
+ quantization_params.zero_point + real_value / quantization_params.scale;
+ const double fractional_scaled_value =
+ scaled_value - std::round(scaled_value);
+ if (std::abs(fractional_scaled_value) > 1e-12) {
+ return false;
+ }
+ const double rounded_scaled_value = std::round(scaled_value);
+ if (data_type == ArrayDataType::kUint8) {
+ if (rounded_scaled_value < 0 || rounded_scaled_value > 255) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool ChooseHardcodedQuantizationForOperatorOutput(
+ const Operator& op, ArrayDataType* quantized_data_type,
+ QuantizationParams* quantization_params) {
+ if (op.type == OperatorType::kL2Normalization) {
+ // L2Normalization has range: [-1, 1].
+ // 0 should be exactly representable, as values will typically be centered
+ // around 0, with many values near 0.
+ *quantized_data_type = ArrayDataType::kUint8;
+ quantization_params->zero_point = 128;
+ quantization_params->scale = 1. / 128.;
+ CHECK(
+ IsExactlyRepresentable(0., *quantized_data_type, *quantization_params));
+ return true;
+ }
+ if ((op.type == OperatorType::kLogistic) ||
+ (op.type == OperatorType::kSoftmax)) {
+ // Logistic and Softmax have range: [0, 1].
+ //
+ // For Logistic, 0.5 should be exactly representable, as implementations
+ // will typically exploit the symmetry logistic(-x) = 1 - logistic(x), and
+ // the glueing of the two halves of the graph will only be seamless if we
+ // are accurately representing logistic(0) == 0.5.
+ *quantized_data_type = ArrayDataType::kUint8;
+ quantization_params->zero_point = 0;
+ quantization_params->scale = 1. / 256.;
+ CHECK(IsExactlyRepresentable(0.5, *quantized_data_type,
+ *quantization_params));
+ return true;
+ }
+ return false;
+}
+
+bool ChooseQuantizationForOperatorOutput(
+ GraphTransformation* transformation, Model* model, const Operator& op,
+ std::size_t output_index, ArrayDataType* quantized_data_type,
+ QuantizationParams* quantization_params) {
+ const auto& output = op.outputs[output_index];
+ auto& array = model->GetArray(output);
+ if (array.data_type != ArrayDataType::kFloat) {
+ return false;
+ }
+ if (ChooseHardcodedQuantizationForOperatorOutput(op, quantized_data_type,
+ quantization_params)) {
+ transformation->AddMessageF(
+ "Output array %s is produced by a %s operator. Choosing fixed "
+ "quantization params accordingly.",
+ output, OperatorTypeName(op.type));
+ return true;
+ }
+ if ((op.type == OperatorType::kDepthToSpace) ||
+ (op.type == OperatorType::kSpaceToDepth)) {
+ // DepthToSpace and SpaceToDepth should preserve the quantization parameters
+ // of the input array, as these are simple reshape operations.
+ const auto& input_quantization_params =
+ model->GetArray(op.inputs[0]).GetQuantizationParams();
+ *quantized_data_type = ArrayDataType::kUint8;
+ quantization_params->zero_point = input_quantization_params.zero_point;
+ quantization_params->scale = input_quantization_params.scale;
+
+ transformation->AddMessageF(
+ "Output array %s is produced by a %s operator. Copying quantization "
+ "params from input array.",
+ output, OperatorTypeName(op.type));
+ return true;
+ }
+ const MinMax& minmax = GetOrComputeMinMax(model, output);
+ GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(model->flags, minmax,
+ quantization_params);
+ *quantized_data_type = ArrayDataType::kUint8;
+ transformation->AddMessageF(
+ "For output array %s with min=%g, max=%g"
+ ", chose to quantize as uint8 with zero_point=%d"
+ ", scale=%g",
+ output, minmax.min, minmax.max, quantization_params->zero_point,
+ quantization_params->scale);
+
+ return true;
+}
+} // namespace
+
+bool Quantize::Run(Model* model, std::size_t op_index) {
+ // Our general "quantization" graph transformation consists in replacing
+ // QuantizedInputArrays[] ->
+ // DequantizeOperators[] ->
+ // FloatInputArrays[] ->
+ // Operator ->
+ // FloatOutputArray
+ // by
+ // QuantizedInputArrays[] ->
+ // Operator ->
+ // QuantizedOutputArray ->
+ // DequantizeOperator ->
+ // FloatOutputArray
+ //
+ // In other words, this is pushing Dequantize operators to the right of
+ // other operators.
+ //
+
+ auto& op = *model->operators[op_index];
+ if (op.type == OperatorType::kDequantize ||
+ op.type == OperatorType::kFakeQuant) {
+ return false;
+ }
+
+ // Our assumption here is that the input arrays are already quantized -
+ // that is typically the case in models operating on an input bitmap
+ // image, and MakeInitialDequantizeOp should have already resolved
+ // the handling of the input image as an initial Dequantize op.
+ //
+ // Thus we are building around the assumption that the graph always starts
+ // with a quantized input array, and only after some Dequantize op do we have
+ // float arrays. The problem of quantizing the graph thus becomes a problem of
+ // pushing Dequantize ops to the right of other ops.
+ //
+ // Let us just guard this assumption by the following assertion:
+ for (const auto& input : op.inputs) {
+ if (IsInputArray(*model, input)) {
+ const auto& input_array = model->GetArray(input);
+ CHECK(input_array.quantization_params);
+ }
+ }
+ if (!SupportsQuantization(op)) {
+ LOG(FATAL) << "Unimplemented: this graph contains an operator of type "
+ << HelpfulOperatorTypeName(op)
+ << " for which the quantized form is not yet implemented. "
+ "Sorry, and patches welcome (that's a relatively fun patch "
+ "to write, mostly providing the actual quantized arithmetic "
+ "code for this op).";
+ }
+
+ for (const auto& input : op.inputs) {
+ const auto& array = model->GetArray(input);
+ if (array.data_type == ArrayDataType::kFloat) {
+ if (!array.minmax && !array.buffer) {
+ LOG(ERROR) << "Can't quantize input array " << input
+ << " because it lacks min/max info";
+ return false;
+ }
+ const auto* other_op = GetOpWithOutput(*model, input);
+ if (other_op && other_op->type != OperatorType::kDequantize) {
+ AddMessageF(
+ "Not quantizing %s for now, because its input array %s is not "
+ "produced by a Dequantize op, "
+ "which means that we should yield and let other ops "
+ "get quantized first",
+ LogName(op), input);
+ return false;
+ }
+ }
+ }
+
+ bool changed = false;
+
+ // Quantize inputs, remove any Dequantize op on the inputs side
+ for (std::size_t input_index = 0; input_index < op.inputs.size();
+ input_index++) {
+ ArrayDataType quantized_data_type;
+ QuantizationParams quantization_params;
+ if (ChooseQuantizationForOperatorInput(this, model, op, input_index,
+ &quantized_data_type,
+ &quantization_params)) {
+ changed = true;
+ const auto& input = op.inputs[input_index];
+ if (IsConstantParameterArray(*model, input)) {
+ QuantizeArray(this, model, input, quantized_data_type,
+ quantization_params);
+ } else {
+ auto dequantize_it = FindOpWithOutput(*model, input);
+ CHECK(dequantize_it != model->operators.end());
+ auto* dequantize_op = dequantize_it->get();
+ CHECK(dequantize_op->type == OperatorType::kDequantize);
+ op.inputs[input_index] = dequantize_op->inputs[0];
+ // Check if the output of that Dequantize op was not used by any
+ // other operator. We will then erase that Dequantize op.
+ if (!CountOpsWithInput(*model, dequantize_op->outputs[0])) {
+ // If any of the model's output_arrays was pointing to the
+ // Dequantize op's output, let it point to the Dequantize op's
+ // input instead.
+ for (int i = 0; i < model->flags.output_arrays_size(); i++) {
+ if (model->flags.output_arrays(i) == dequantize_op->outputs[0]) {
+ model->flags.set_output_arrays(i, dequantize_op->inputs[0]);
+ }
+ }
+ model->arrays.erase(dequantize_op->outputs[0]);
+ model->operators.erase(dequantize_it);
+ }
+ }
+ }
+ }
+
+ // Quantize outputs, add Dequantize ops as needed on the outputs side
+ for (std::size_t output_index = 0; output_index < op.outputs.size();
+ output_index++) {
+ ArrayDataType quantized_data_type;
+ QuantizationParams quantization_params;
+ if (ChooseQuantizationForOperatorOutput(this, model, op, output_index,
+ &quantized_data_type,
+ &quantization_params)) {
+ changed = true;
+ const auto& output = op.outputs[output_index];
+ QuantizeArray(this, model, output, quantized_data_type,
+ quantization_params);
+ const auto& dequantized_output =
+ AvailableArrayName(*model, output + "_dequantized");
+ const auto& output_array = model->GetArray(output);
+ const auto& output_minmax = output_array.GetMinMax();
+ auto& dequantized_output_array =
+ model->GetOrCreateArray(dequantized_output);
+ dequantized_output_array.data_type = ArrayDataType::kFloat;
+ auto& dequantized_output_minmax =
+ dequantized_output_array.GetOrCreateMinMax();
+ dequantized_output_minmax.min = output_minmax.min;
+ dequantized_output_minmax.max = output_minmax.max;
+ for (const auto& other_op : model->operators) {
+ for (auto& other_op_input : other_op->inputs) {
+ if (other_op_input == output) {
+ other_op_input = dequantized_output;
+ }
+ }
+ }
+ auto* dequantize_op = new DequantizeOperator;
+ dequantize_op->inputs = {output};
+ dequantize_op->outputs = {dequantized_output};
+ for (int i = 0; i < model->flags.output_arrays_size(); i++) {
+ if (model->flags.output_arrays(i) == output) {
+ model->flags.set_output_arrays(i, dequantized_output);
+ }
+ }
+ const auto op_it = FindOp(*model, &op);
+ model->operators.emplace(op_it + 1, dequantize_op);
+ }
+ }
+
+ return changed;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc
new file mode 100644
index 0000000000..371ced388a
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc
@@ -0,0 +1,105 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool ApplyMinMaxToArray(GraphTransformation* transformation, Model* model,
+ const MinMax& minmax, const string& array_name) {
+ auto& annotated_array = model->GetArray(array_name);
+ if (annotated_array.minmax) {
+ return false;
+ }
+ annotated_array.GetOrCreateMinMax() = minmax;
+ transformation->AddMessageF(
+ "Read min/max annotation for array %s: min=%g, max=%g", array_name,
+ minmax.min, minmax.max);
+ return true;
+}
+
+} // end namespace
+
+bool ReadFakeQuantMinMax::Run(Model* model, std::size_t op_index) {
+ const auto fakequant_it = model->operators.begin() + op_index;
+ auto* fakequant_base_op = fakequant_it->get();
+ if (fakequant_base_op->type != OperatorType::kFakeQuant) {
+ return false;
+ }
+ auto* fakequant_op = static_cast<FakeQuantOperator*>(fakequant_base_op);
+
+ bool changed = false;
+
+ if (!fakequant_op->minmax) {
+ CHECK_EQ(fakequant_op->inputs.size(), 3);
+ // We need to yield until the min and max parameters have been
+ // resolved to constant arrays.
+ for (int i = 1; i <= 2; i++) {
+ if (!IsConstantParameterArray(*model, fakequant_op->inputs[1])) {
+ return false;
+ }
+ }
+
+ // Obtain the final min/max values
+ const auto& min_array = model->GetArray(fakequant_op->inputs[1]);
+ const auto& max_array = model->GetArray(fakequant_op->inputs[2]);
+ CHECK_EQ(RequiredBufferSizeForShape(min_array.shape()), 1);
+ CHECK_EQ(RequiredBufferSizeForShape(max_array.shape()), 1);
+ fakequant_op->minmax.reset(new MinMax);
+ MinMax& minmax = *fakequant_op->minmax;
+ minmax.min = min_array.GetBuffer<ArrayDataType::kFloat>().data[0];
+ minmax.max = max_array.GetBuffer<ArrayDataType::kFloat>().data[0];
+ // We always want [min, max] to contain 0.
+ minmax.min = std::min(minmax.min, 0.);
+ minmax.max = std::max(minmax.max, 0.);
+
+ // We won't use the input arrays that provided these min and max
+ // values, anymore. Delete them unless they are used by something
+ // else.
+ for (int i = 1; i <= 2; i++) {
+ if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) {
+ model->arrays.erase(fakequant_op->inputs[i]);
+ }
+ }
+ fakequant_op->inputs.resize(1);
+ changed = true;
+ }
+
+ // At this point, this FakeQuantOperator should have a MinMax
+ // attached to it, and should only have 1 input (it should not have
+ // 2nd and 3rd input arrays giving min and max anymore).
+ CHECK(fakequant_op->minmax);
+ CHECK_EQ(1, fakequant_op->inputs.size());
+
+ const MinMax& minmax = *fakequant_op->minmax;
+
+ // Record the MinMax info on the input and output arrays
+ changed |= ApplyMinMaxToArray(this, model, minmax, fakequant_op->inputs[0]);
+ changed |= ApplyMinMaxToArray(this, model, minmax, fakequant_op->outputs[0]);
+
+ return changed;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc
new file mode 100644
index 0000000000..3992e7d1ef
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc
@@ -0,0 +1,59 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool RemoveFinalDequantizeOp::Run(Model* model, std::size_t op_index) {
+ const auto dequantize_it = model->operators.begin() + op_index;
+ const auto* dequantize_op = dequantize_it->get();
+ if (dequantize_op->type != OperatorType::kDequantize) {
+ return false;
+ }
+ const auto& output = dequantize_op->outputs[0];
+ // We can remove any dequantize op whose output is not consumed by
+ // any op. This is not necessarily equivalent to the output being
+ // one of the model's output arrays, as some intermediate array
+ // in the middle of the graph might be designated as an output
+ // array.
+ if (CountOpsWithInput(*model, output)) {
+ return false;
+ }
+
+ // If one of the model's output arrays was actually the Dequantize op's
+ // output, then we need to update it to point to the Dequantize op's input.
+ for (int i = 0; i < model->flags.output_arrays_size(); i++) {
+ if (output == model->flags.output_arrays(i)) {
+ model->flags.set_output_arrays(i, dequantize_op->inputs[0]);
+ }
+ }
+
+ // Remove the node and its output array.
+ AddMessageF("Removed final %s", LogName(*dequantize_op));
+ model->arrays.erase(output);
+ model->operators.erase(dequantize_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc
new file mode 100644
index 0000000000..35a0c46532
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc
@@ -0,0 +1,60 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool RemoveTensorFlowAssert::Run(Model* model, std::size_t op_index) {
+ const auto assert_it = model->operators.begin() + op_index;
+ const auto* assert_op = assert_it->get();
+ if (assert_op->type != OperatorType::kTensorFlowAssert) {
+ return false;
+ }
+
+ bool changed = false;
+ // Remove any other node's dependency on this assert node
+ for (const auto& op : model->operators) {
+ auto it = op->inputs.begin();
+ while (it != op->inputs.end()) {
+ if (*it == assert_op->outputs[0]) {
+ op->inputs.erase(it);
+ changed = true;
+ } else {
+ ++it;
+ }
+ }
+ }
+ CHECK(!CountOpsWithInput(*model, assert_op->outputs[0]));
+
+ if (changed) {
+ AddMessageF(
+ "Prepared for the removal of %s by removing any other op's dependency "
+ "on it",
+ LogName(*assert_op));
+ }
+
+ // That's it. We can stop here, no need to duplicate the work that
+ // RemoveUnusedOp will do removing this now-unused node.
+ return changed;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc
new file mode 100644
index 0000000000..404269bbfd
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc
@@ -0,0 +1,38 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool RemoveTensorFlowIdentity::Run(Model* model, std::size_t op_index) {
+ const auto passthru_it = model->operators.begin() + op_index;
+ const auto* passthru_op = passthru_it->get();
+ if (passthru_op->type != OperatorType::kTensorFlowIdentity) {
+ return false;
+ }
+
+ return RemoveTrivialPassthroughOp(this, model, op_index);
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc
new file mode 100644
index 0000000000..6add443f2d
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc
@@ -0,0 +1,113 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <iterator>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+template <typename Scalar>
+bool AreAllBufferElementsEqualTo(const std::vector<Scalar>& buffer_data,
+ Scalar value) {
+ for (auto x : buffer_data) {
+ if (x != value) {
+ return false;
+ }
+ }
+ return true;
+}
+} // namespace
+
+// A binary operator is called trivial when exactly one of its operands is
+// a constant and is such that the binary operation is equivalent to
+// the identity operation on its other input.
+// For example, an Add operator is trivial if
+// one of its operands is constant 0, a Mul operator is trivial
+// if one of its operands is constant 1, etc.
+bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
+ const auto binary_it = model->operators.begin() + op_index;
+ auto* binary_op = binary_it->get();
+ if (binary_op->type != OperatorType::kAdd &&
+ binary_op->type != OperatorType::kMul &&
+ binary_op->type != OperatorType::kSub &&
+ binary_op->type != OperatorType::kDiv) {
+ return false;
+ }
+
+ CHECK_EQ(binary_op->inputs.size(), 2);
+
+ // This graph transformation is only concerned with the case
+ // when one input is constant and the other is not constant.
+ const bool is_input_constant[2] = {
+ IsConstantParameterArray(*model, binary_op->inputs[0]),
+ IsConstantParameterArray(*model, binary_op->inputs[1]),
+ };
+ if (!is_input_constant[0] && !is_input_constant[1]) {
+ // Neither input is constant, so nothing we can resolve here.
+ return false;
+ }
+ if (is_input_constant[0] && is_input_constant[1]) {
+ // Both inputs are constants. That's a job for constants
+ // propagation, not for us to handle here.
+ return false;
+ }
+ const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
+ const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
+ CHECK(is_input_constant[index_of_constant_input]);
+ CHECK(!is_input_constant[index_of_variable_input]);
+
+ // Now check if the constant operand makes this binary
+ // operator trivial.
+ const auto& constant_input_array =
+ *model->arrays[binary_op->inputs[index_of_constant_input]];
+ // For now, we only handle floats here.
+ if (constant_input_array.data_type != ArrayDataType::kFloat) {
+ return false;
+ }
+ const auto& constant_input_float_data =
+ constant_input_array.GetBuffer<ArrayDataType::kFloat>().data;
+ bool is_trivial = false;
+ if (binary_op->type != OperatorType::kAdd) {
+ is_trivial = AreAllBufferElementsEqualTo(constant_input_float_data, 0.f);
+ } else if (binary_op->type != OperatorType::kSub) {
+ is_trivial = index_of_constant_input == 1 &&
+ AreAllBufferElementsEqualTo(constant_input_float_data, 0.f);
+ } else if (binary_op->type != OperatorType::kMul) {
+ is_trivial = AreAllBufferElementsEqualTo(constant_input_float_data, 1.f);
+ } else if (binary_op->type != OperatorType::kDiv) {
+ is_trivial = index_of_constant_input == 1 &&
+ AreAllBufferElementsEqualTo(constant_input_float_data, 1.f);
+ }
+
+ if (!is_trivial) {
+ return false;
+ }
+
+ // Now we know that this node is trivial, so we can remove it.
+ AddMessageF("Removing trivial %s", LogName(*binary_op));
+ return RemoveTrivialPassthroughOp(this, model, op_index);
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc
new file mode 100644
index 0000000000..3ceb93d8ee
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc
@@ -0,0 +1,40 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool RemoveTrivialConcatenation::Run(Model* model, std::size_t op_index) {
+ const auto concat_it = model->operators.begin() + op_index;
+ auto* concat_op = concat_it->get();
+ if (concat_op->type != OperatorType::kConcatenation) {
+ return false;
+ }
+ if (concat_op->inputs.size() != 1) {
+ return false;
+ }
+ return RemoveTrivialPassthroughOp(this, model, op_index);
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc
new file mode 100644
index 0000000000..b603735704
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc
@@ -0,0 +1,68 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) {
+ // TensorFlow allows Concatenation nodes to have 0-D inputs,
+ // and they are then treated as empty i.e. omitted from concatenation,
+ // in violation of the notion that 0-D is equivalent to 1x1x1x1.
+ // Thus we have to drop these 0-D inputs from Concatenation nodes.
+ // Sometimes, there will remain only one non-trivial input, and
+ // the other graph transformation RemoveTrivialConcatenation will then drop
+ // it.
+ const auto concat_it = model->operators.begin() + op_index;
+ auto* concat_op = concat_it->get();
+ if (concat_op->type != OperatorType::kConcatenation) {
+ return false;
+ }
+ std::vector<string> trivial_inputs;
+ std::vector<string> nontrivial_inputs;
+ for (const string& input : concat_op->inputs) {
+ const auto& input_array = model->GetArray(input);
+ const bool is_trivial =
+ input_array.has_shape() && input_array.shape().dimensions_count() == 0;
+ if (is_trivial) {
+ trivial_inputs.push_back(input);
+ } else {
+ nontrivial_inputs.push_back(input);
+ }
+ }
+
+ if (trivial_inputs.empty()) {
+ return false;
+ }
+
+ // Drop trivial inputs.
+ for (const string& input : trivial_inputs) {
+ if (CountOpsWithInput(*model, input) == 1) {
+ model->arrays.erase(input);
+ }
+ }
+ concat_op->inputs = nontrivial_inputs;
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
new file mode 100644
index 0000000000..a0d1338298
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
@@ -0,0 +1,107 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+// Reroute all edges involving a given discardable array to another
+// array instead. from_array is assumed to be discardable, and consequently
+// this only updates operator edges (since discardable arrays only
+// appear there, and not e.g. in model flags).
+void RerouteEdges(const string& from_array, const string& to_array,
+ Model* model) {
+ for (const auto& op : model->operators) {
+ for (auto& output : op->outputs) {
+ if (output == from_array) {
+ output = to_array;
+ }
+ }
+ for (auto& input : op->inputs) {
+ if (input == from_array) {
+ input = to_array;
+ }
+ }
+ }
+}
+
+} // end anonymous namespace
+
+bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
+ Model* model, std::size_t op_index) {
+ const auto passthru_it = model->operators.begin() + op_index;
+ auto* passthru_op = passthru_it->get();
+ CHECK_EQ(passthru_op->outputs.size(), 1);
+ CHECK_GE(passthru_op->inputs.size(), 1);
+ int count_nonconstant_input_arrays = 0;
+ // We call 'main input' the unique nonconstant input array if there is one,
+ // or else the 0-th input.
+ int main_input_array_index = 0;
+ for (int i = 0; i < passthru_op->inputs.size(); i++) {
+ if (!model->GetArray(passthru_op->inputs[i]).buffer) {
+ count_nonconstant_input_arrays++;
+ main_input_array_index = i;
+ }
+ }
+ CHECK_LE(count_nonconstant_input_arrays, 1);
+
+ const string main_input_name = passthru_op->inputs[main_input_array_index];
+ const string output_name = passthru_op->outputs[0];
+ if (IsDiscardableArray(*model, output_name)) {
+ transformation->AddMessageF(
+ "Removing %s, keeping its non-constant input array",
+ LogName(*passthru_op));
+ model->arrays.erase(output_name);
+ for (const string& input : passthru_op->inputs) {
+ if (IsDiscardableArray(*model, input) && input != main_input_name &&
+ CountOpsWithInput(*model, input) == 1) {
+ model->arrays.erase(input);
+ }
+ }
+ RerouteEdges(output_name, main_input_name, model);
+ } else if (IsDiscardableArray(*model, main_input_name)) {
+ transformation->AddMessageF("Removing %s, keeping its output array",
+ LogName(*passthru_op));
+ for (const string& input : passthru_op->inputs) {
+ if (IsDiscardableArray(*model, input) &&
+ (input == main_input_name || CountOpsWithInput(*model, input) == 1)) {
+ model->arrays.erase(input);
+ }
+ }
+ RerouteEdges(main_input_name, output_name, model);
+ } else {
+ transformation->AddMessageF(
+ "Cannot remove %s, neither its nonconstant input nor its output may be "
+ "discarded",
+ LogName(*passthru_op));
+ return false;
+ }
+
+ // Remove the pass-through node.
+ model->operators.erase(passthru_it);
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h
new file mode 100644
index 0000000000..b72c85c0e5
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h
@@ -0,0 +1,55 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+
+namespace toco {
+
+// A "passthrough op" is an op that satisfies the following conditions:
+// 1. It has at most one non-constant input (it may have other constant
+// inputs).
+// 2. It has exactly one output.
+// 3. It forwards exactly its single non-constant input to its single output.
+//
+// Examples include:
+// 1. TensorFlow Identity ops. (Have one input).
+// 2. TensorFlow Reshape ops when the input and output shapes agree.
+// 3. Any binary operator, one of whose two inputs is a constant and is the
+// neutral value for that operation. For example, a binary Add operator
+// where one of its inputs is a constant array filled with zeros.
+//
+// A passthrough op is "trivial" and can be removed when it is possible to
+// discard either its single non-constant input or output array, rerouting any
+// edge involving it to the other of these two arrays.
+//
+// It is only possible to discard such an array if it is not explicitly
+// designated as a global input/output array of the graph, e.g. the model's
+// input arrays, output arrays, and any array involved in a RNN back-edge
+// specified by the model.
+//
+// This function does not check that the given operator is a passthrough op:
+// that's the responsibility of the caller.
+// Given that it is a passthrough op, this function checks whether it is trivial
+// and then discards it and returns true, or, if it's not trivial (if neither
+// the input nor the output may be discarded), returns false.
+bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
+ Model* model, std::size_t op_index);
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc
new file mode 100644
index 0000000000..28f76c9d36
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc
@@ -0,0 +1,87 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool RemoveTrivialQuantizedActivationFunc::Run(Model* model,
+ std::size_t op_index) {
+ const auto it = model->operators.begin() + op_index;
+ auto* op = it->get();
+ if (op->fused_activation_function != FusedActivationFunctionType::kRelu &&
+ op->fused_activation_function != FusedActivationFunctionType::kRelu6) {
+ return false;
+ }
+ const auto& output_array = model->GetArray(op->outputs[0]);
+ if (!output_array.quantization_params) {
+ return false;
+ }
+ if (output_array.data_type != ArrayDataType::kUint8) {
+ return false;
+ }
+ const auto& quantization_params = output_array.GetQuantizationParams();
+
+ bool has_nontrivial_min_bound = false;
+ bool has_nontrivial_max_bound = false;
+
+ if (op->fused_activation_function == FusedActivationFunctionType::kRelu ||
+ op->fused_activation_function == FusedActivationFunctionType::kRelu6) {
+ double lowest_representable_output =
+ (0. - quantization_params.zero_point) * quantization_params.scale;
+ if (lowest_representable_output < 0.) {
+ has_nontrivial_min_bound = true;
+ AddMessageF(
+ "Quantized activation function is not trivial: "
+ "the lowest representable output value %g"
+ " less than the clamp min bound.",
+ lowest_representable_output);
+ }
+ }
+ if (op->fused_activation_function == FusedActivationFunctionType::kRelu6) {
+ double highest_representable_output =
+ (255. - quantization_params.zero_point) * quantization_params.scale;
+ if (highest_representable_output > 6.) {
+ has_nontrivial_max_bound = true;
+ AddMessageF(
+ "Quantized activation function is not trivial: "
+ "the highest representable output value %g"
+ " is greater than the clamp max bound.",
+ highest_representable_output);
+ }
+ }
+
+ if (has_nontrivial_min_bound || has_nontrivial_max_bound) {
+ return false;
+ }
+
+ op->fused_activation_function = FusedActivationFunctionType::kNone;
+ AddMessageF(
+ "Removing trivial quantized activation function on %s"
+ " because the output quantization parameters imply at least as tight"
+ " a clamp anyway.",
+ LogName(*op));
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc
new file mode 100644
index 0000000000..90f9381ec1
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc
@@ -0,0 +1,92 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <iterator>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool IsReshapeTrivial(const Model& model, const Operator& op,
+ RemoveTrivialReshape* transformation) {
+ CHECK(op.type == OperatorType::kTensorFlowReshape);
+
+ // One way in which a reshape can be trivial is if its
+ // output shape is == its input shape
+ const auto& input_array = model.GetArray(op.inputs[0]);
+ const auto& output_array = model.GetArray(op.outputs[0]);
+ if (input_array.has_shape() && output_array.has_shape()) {
+ if (transformation->treat_expand_dims_as_trivial() &&
+ ShapesAgreeUpToExtending(input_array.shape(), output_array.shape())) {
+ transformation->AddMessageF(
+ "%s is trivial because its input and output shapes are equal up to "
+ "extending "
+ "by 1's, and we are told to aggressively discard such Reshape ops.",
+ LogName(op));
+ return true;
+ }
+ if (input_array.shape().dims() == output_array.shape().dims()) {
+ transformation->AddMessageF(
+ "%s is trivial because its input and output shapes are equal",
+ LogName(op));
+ return true;
+ }
+ }
+
+ // Another way in which a reshape can be trivial is if its output
+ // is only consumed by another reshape.
+ if (CountOpsWithInput(model, op.outputs[0]) == 1) {
+ const auto* next_op = GetOpWithInput(model, op.outputs[0]);
+ if (next_op->type == OperatorType::kTensorFlowReshape) {
+ transformation->AddMessageF(
+ "%s is trivial because its output is only consumed by another "
+ "Reshape op",
+ LogName(op));
+ return true;
+ }
+ }
+
+ return false;
+}
+
+} // namespace
+
+bool RemoveTrivialReshape::Run(Model* model, std::size_t op_index) {
+ const auto reshape_it = model->operators.begin() + op_index;
+ auto* reshape_op = reshape_it->get();
+ if (reshape_op->type != OperatorType::kTensorFlowReshape) {
+ return false;
+ }
+
+ if (!IsReshapeTrivial(*model, *reshape_op, this)) {
+ return false;
+ }
+
+ AddMessageF("Removing trivial %s", LogName(*reshape_op));
+
+ CHECK_EQ(reshape_op->inputs.size(), 2);
+ return RemoveTrivialPassthroughOp(this, model, op_index);
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
new file mode 100644
index 0000000000..1f1f1f6948
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
@@ -0,0 +1,122 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
+ const auto it = model->operators.begin() + op_index;
+ const auto* op = it->get();
+
+ // Bail if any output is used, and is not an input_array of
+ // the model. We allow specifying an arbitrary input_array,
+ // treating the part of the graph leading up to it as unused.
+ for (const auto& output : op->outputs) {
+ CHECK(model->arrays.count(output));
+ // If this output is provided as the model's input array,
+ // then we don't need this operator to produce its contents.
+ if (IsInputArray(*model, output)) {
+ continue;
+ }
+ // If this output is provided as a RNN's state array,
+ // then we don't need this operator to produce its contents.
+ // So far this case has only been encountered with TensorFlow
+ // Fill ops used to zero-initialize RNN states, which is
+ // redundant for us as we zero-initialize RNN states anyway.
+ bool found_output_as_rnn_state_array = false;
+ for (const auto& rnn_state : model->flags.rnn_states()) {
+ if (output == rnn_state.state_array()) {
+ CHECK(op->type == OperatorType::kTensorFlowUnsupported);
+ CHECK_EQ(static_cast<const TensorFlowUnsupportedOperator*>(op)
+ ->tensorflow_op,
+ "Fill");
+ found_output_as_rnn_state_array = true;
+ break;
+ }
+ }
+ if (found_output_as_rnn_state_array) {
+ continue;
+ }
+ for (const string& output_array : model->flags.output_arrays()) {
+ if (output == output_array) {
+ return false;
+ }
+ }
+ for (const auto& rnn_state : model->flags.rnn_states()) {
+ if (output == rnn_state.back_edge_source_array()) {
+ return false;
+ }
+ }
+ if (CountOpsWithInput(*model, output)) {
+ return false;
+ }
+ }
+
+ if (op->unresolved_outputs) {
+ AddMessageF("Not discarding %s because it has unresolved outputs.",
+ LogName(*op));
+ return false;
+ }
+
+ AddMessageF("Discarding %s because none of its outputs is used.",
+ LogName(*op));
+
+ // At that point we know that none of the outputs is used, so we will
+ // definitely remove the node and all its outputs.
+
+ // Remove any input array that is not used by anything else,
+ // and that is not the output of some other operator.
+ for (const auto& input : op->inputs) {
+ if (CountOpsWithInput(*model, input) == 1 &&
+ !GetOpWithOutput(*model, input)) {
+ model->arrays.erase(input);
+ }
+ }
+
+ // Remove the node and its now-unused output arrays.
+ for (const auto& output : op->outputs) {
+ // If the output array is the model's input array, don't remove that.
+ // That's the case when cropping a model at a given --input_array.
+ if (IsInputArray(*model, output)) {
+ continue;
+ }
+ // Likewise, if the output array is a RNN state array, don't remove that.
+ bool found_output_as_rnn_state_array = false;
+ for (const auto& rnn_state : model->flags.rnn_states()) {
+ if (output == rnn_state.state_array()) {
+ found_output_as_rnn_state_array = true;
+ break;
+ }
+ }
+ if (found_output_as_rnn_state_array) {
+ continue;
+ }
+ // Generic case: do delete this output array.
+ model->arrays.erase(output);
+ }
+ model->operators.erase(it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc
new file mode 100644
index 0000000000..3eb7fa3896
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc
@@ -0,0 +1,135 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) {
+ auto bn_it = model->operators.begin() + op_index;
+ if (bn_it->get()->type != OperatorType::kBatchNormalization) {
+ return false;
+ }
+ const auto* bn_op =
+ static_cast<const BatchNormalizationOperator*>(bn_it->get());
+
+ const auto& mean_array = model->GetArray(bn_op->inputs[1]);
+ const auto& multiplier_array = model->GetArray(bn_op->inputs[2]);
+ const auto& offset_array = model->GetArray(bn_op->inputs[3]);
+
+ CHECK(IsConstantParameterArray(*model, bn_op->inputs[1]) &&
+ IsConstantParameterArray(*model, bn_op->inputs[2]) &&
+ IsConstantParameterArray(*model, bn_op->inputs[3]))
+ << "Batch normalization resolution requires that mean, multiplier and "
+ "offset arrays be constant.";
+
+ // We should only have *float* BatchNormalizations... let's guard this
+ // assumption by CHECK's.
+ CHECK(mean_array.data_type == ArrayDataType::kFloat);
+ CHECK(multiplier_array.data_type == ArrayDataType::kFloat);
+ CHECK(offset_array.data_type == ArrayDataType::kFloat);
+
+ // Create the new Mul, Add operators
+ auto* mul_op = new MulOperator;
+ auto* add_op = new AddOperator;
+ const string mul_name =
+ AvailableArrayName(*model, bn_op->outputs[0] + "_mul");
+ const string add_name =
+ AvailableArrayName(*model, bn_op->outputs[0] + "_add");
+ const string mul_param_name = AvailableArrayName(*model, mul_name + "_param");
+ const string add_param_name = AvailableArrayName(*model, add_name + "_param");
+ mul_op->inputs = {bn_op->inputs[0], mul_param_name};
+ mul_op->outputs = {mul_name};
+ add_op->inputs = {mul_name, add_param_name};
+ add_op->outputs = {bn_op->outputs[0]};
+ AddMessageF("Splitting %s into %s and %s", LogName(*bn_op), LogName(*mul_op),
+ LogName(*add_op));
+
+ // Create the intermediate activation array (output of mul, input of add)
+ auto& intermediate_array = model->GetOrCreateArray(mul_op->outputs[0]);
+ intermediate_array.data_type = model->GetArray(bn_op->inputs[0]).data_type;
+
+ // Insert the new operators in the graph
+ auto add_it = model->operators.emplace(bn_it, add_op);
+ auto mul_it = model->operators.emplace(add_it, mul_op);
+ // update invalidated iterators.
+ DCHECK_EQ(mul_it->get(), mul_op);
+ add_it = mul_it + 1;
+ DCHECK_EQ(add_it->get(), add_op);
+ bn_it = add_it + 1;
+ DCHECK_EQ(bn_it->get(), bn_op);
+
+ // Create the new param arrays
+ const auto& mean_shape = mean_array.shape();
+ const auto& multiplier_shape = multiplier_array.shape();
+ const auto& offset_shape = offset_array.shape();
+ CHECK(mean_shape.dims() == multiplier_shape.dims());
+ CHECK(mean_shape.dims() == offset_shape.dims());
+ const auto& param_shape = mean_shape;
+ const int buffer_size = RequiredBufferSizeForShape(param_shape);
+ auto& mul_param_array = model->GetOrCreateArray(mul_param_name);
+ auto& add_param_array = model->GetOrCreateArray(add_param_name);
+ DropMinMax(model, mul_param_name);
+ DropMinMax(model, add_param_name);
+ mul_param_array.copy_shape(param_shape);
+ add_param_array.copy_shape(param_shape);
+ mul_param_array.data_type = ArrayDataType::kFloat;
+ add_param_array.data_type = ArrayDataType::kFloat;
+ auto& mul_float_data =
+ mul_param_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
+ auto& add_float_data =
+ add_param_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
+ mul_float_data.resize(buffer_size);
+ add_float_data.resize(buffer_size);
+ const auto& mean_float_data =
+ mean_array.GetBuffer<ArrayDataType::kFloat>().data;
+ const auto& multiplier_float_data =
+ multiplier_array.GetBuffer<ArrayDataType::kFloat>().data;
+ const auto& offset_float_data =
+ offset_array.GetBuffer<ArrayDataType::kFloat>().data;
+
+ CHECK(mul_float_data.size() == buffer_size);
+ CHECK(add_float_data.size() == buffer_size);
+ CHECK(mean_float_data.size() == buffer_size);
+ CHECK(multiplier_float_data.size() == buffer_size);
+ CHECK(offset_float_data.size() == buffer_size);
+
+ for (int i = 0; i < buffer_size; i++) {
+ mul_float_data[i] = multiplier_float_data[i];
+ add_float_data[i] =
+ offset_float_data[i] - mean_float_data[i] * multiplier_float_data[i];
+ }
+
+ // Remove the old param arrays
+ model->arrays.erase(bn_op->inputs[1]);
+ model->arrays.erase(bn_op->inputs[2]);
+ model->arrays.erase(bn_op->inputs[3]);
+
+ // Remove the old operator
+ DCHECK_EQ(bn_it->get(), bn_op);
+ model->operators.erase(bn_it);
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
new file mode 100644
index 0000000000..53e1be7a05
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
@@ -0,0 +1,247 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+std::vector<bool> VectorGreaterThan(const std::vector<int>& a,
+ const std::vector<int>& b) {
+ DCHECK_EQ(a.size(), b.size());
+ const int size = a.size();
+ std::vector<bool> result(size);
+ for (int i = 0; i < size; i++) {
+ result[i] = a[i] > b[i];
+ }
+ return result;
+}
+
+void PairwiseVectorSelect(const std::vector<bool>& selector,
+ const std::vector<int>& input_a,
+ const std::vector<int>& input_b,
+ std::vector<int>* output_a,
+ std::vector<int>* output_b) {
+ DCHECK_EQ(input_a.size(), input_b.size());
+ DCHECK_EQ(output_a->size(), output_b->size());
+ DCHECK_EQ(input_a.size(), output_a->size());
+ DCHECK_EQ(selector.size(), input_a.size());
+ const int size = input_a.size();
+ for (int i = 0; i < size; i++) {
+ if (selector[i]) {
+ (*output_a)[i] = input_a[i];
+ (*output_b)[i] = input_b[i];
+ } else {
+ (*output_a)[i] = input_b[i];
+ (*output_b)[i] = input_a[i];
+ }
+ }
+}
+
+template <ArrayDataType InputsDataType, ArrayDataType OutputDataType>
+void EvaluateBinaryOperatorOnConstantInputs(Model* model,
+ const Operator* binary_op) {
+ CHECK(IsConstantParameterArray(*model, binary_op->inputs[0]));
+ CHECK(IsConstantParameterArray(*model, binary_op->inputs[1]));
+ CHECK(binary_op->fused_activation_function ==
+ FusedActivationFunctionType::kNone);
+ const auto& input0_array = model->GetArray(binary_op->inputs[0]);
+ const auto& input1_array = model->GetArray(binary_op->inputs[1]);
+ const auto& output_name = binary_op->outputs[0];
+ auto& output_array = model->GetArray(output_name);
+ CHECK(input0_array.data_type == InputsDataType);
+ CHECK(input1_array.data_type == InputsDataType);
+ CHECK(output_array.data_type == OutputDataType);
+
+ // We have already tested above for existence of input buffers
+ // (synonymous to being a constant param).
+ CHECK(input0_array.buffer);
+ CHECK(input1_array.buffer);
+ // On the other hand, the output should not already have a buffer.
+ CHECK(!output_array.buffer);
+
+ const auto& input0_data = input0_array.GetBuffer<InputsDataType>().data;
+ const auto& input1_data = input1_array.GetBuffer<InputsDataType>().data;
+ // Create the buffer on the output array, effectively turning it into
+ // a constant parameter
+
+ const Shape& output_shape = output_array.shape();
+ auto& output_data = output_array.GetMutableBuffer<OutputDataType>().data;
+ const int output_buffer_size = RequiredBufferSizeForShape(output_shape);
+ output_data.resize(output_buffer_size);
+ const int dims_count = output_shape.dimensions_count();
+
+ // It will be convenient here to have copies of the operands shapes
+ // extended to match the number of dimensions of the output shape.
+ Shape input0_shape = input0_array.shape();
+ Shape input1_shape = input1_array.shape();
+ ExtendShape(&input0_shape, dims_count);
+ ExtendShape(&input1_shape, dims_count);
+ // Now we may still have operands of different sizes, which would indicate
+ // that we have to "broadcast" the smaller dimension. We do this using a
+ // a vector of Booleans indicating which input is the larger in each
+ // dimension.
+ CHECK_EQ(input0_shape.dimensions_count(), input1_shape.dimensions_count());
+ CHECK_EQ(input0_shape.dimensions_count(), dims_count);
+ const std::vector<bool> input0_larger =
+ VectorGreaterThan(input0_shape.dims(), input1_shape.dims());
+
+ std::vector<int> big_sizes(dims_count);
+ std::vector<int> small_sizes(dims_count);
+ PairwiseVectorSelect(input0_larger, input0_shape.dims(), input1_shape.dims(),
+ &big_sizes, &small_sizes);
+
+ // The output should already be correctly sized to match the big dimensions.
+ for (int i = 0; i < dims_count; i++) {
+ CHECK_EQ(output_shape.dims(i), big_sizes[i]);
+ }
+
+ std::vector<int> input0_indices(dims_count);
+ std::vector<int> input1_indices(dims_count);
+ std::vector<int> modulo_indices(dims_count);
+
+ for (int k = 0; k < output_buffer_size; k++) {
+ const std::vector<int> output_indices = ReverseOffset(output_shape, k);
+ for (int i = 0; i < dims_count; i++) {
+ modulo_indices[i] = output_indices[i] % small_sizes[i];
+ }
+ PairwiseVectorSelect(input0_larger, output_indices, modulo_indices,
+ &input0_indices, &input1_indices);
+ const auto val0 = input0_data[Offset(input0_shape, input0_indices)];
+ const auto val1 = input1_data[Offset(input1_shape, input1_indices)];
+
+ DataType<OutputDataType> outval;
+ if (binary_op->type == OperatorType::kAdd) {
+ outval = val0 + val1;
+ } else if (binary_op->type == OperatorType::kMul) {
+ outval = val0 * val1;
+ } else if (binary_op->type == OperatorType::kSub) {
+ outval = val0 - val1;
+ } else if (binary_op->type == OperatorType::kDiv) {
+ outval = val0 / val1;
+ } else if (binary_op->type == OperatorType::kTensorFlowMinimum) {
+ outval = std::min(val0, val1);
+ } else if (binary_op->type == OperatorType::kTensorFlowMaximum) {
+ outval = std::max(val0, val1);
+ } else if (binary_op->type == OperatorType::kTensorFlowLess) {
+ outval = val0 < val1;
+ } else if (binary_op->type == OperatorType::kTensorFlowLessEqual) {
+ outval = val0 <= val1;
+ } else if (binary_op->type == OperatorType::kTensorFlowGreater) {
+ outval = val0 > val1;
+ } else if (binary_op->type == OperatorType::kTensorFlowGreaterEqual) {
+ outval = val0 >= val1;
+ } else {
+ LOG(FATAL) << "should not get here";
+ }
+ output_data[Offset(output_shape, output_indices)] = outval;
+ }
+}
+
+void EvaluateBinaryOperatorOnConstantInputs(Model* model,
+ const Operator* binary_op) {
+ const auto inputs_data_type = model->arrays[binary_op->inputs[0]]->data_type;
+ const auto output_data_type = model->arrays[binary_op->outputs[0]]->data_type;
+#define TOCO_HANDLE_CASE(InputsDataType, OutputDataType) \
+ if (inputs_data_type == InputsDataType && \
+ output_data_type == OutputDataType) { \
+ EvaluateBinaryOperatorOnConstantInputs<InputsDataType, OutputDataType>( \
+ model, binary_op); \
+ return; \
+ }
+ TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kFloat)
+ TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kBool)
+ TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kInt32)
+ TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kBool)
+ TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kInt64)
+ TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kBool)
+ LOG(FATAL) << "Unimplemented: don't know how to resolve a constant "
+ << "binary operator for these data types.";
+#undef TOCO_HANDLE_CASE
+}
+} // namespace
+
+bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
+ const auto binary_it = model->operators.begin() + op_index;
+ const auto* binary_op = binary_it->get();
+ // Test for binary ops of types that we know how to resolve
+ if (binary_op->type != OperatorType::kAdd &&
+ binary_op->type != OperatorType::kMul &&
+ binary_op->type != OperatorType::kSub &&
+ binary_op->type != OperatorType::kDiv &&
+ binary_op->type != OperatorType::kTensorFlowMinimum &&
+ binary_op->type != OperatorType::kTensorFlowMaximum &&
+ binary_op->type != OperatorType::kTensorFlowLess &&
+ binary_op->type != OperatorType::kTensorFlowLessEqual &&
+ binary_op->type != OperatorType::kTensorFlowGreater &&
+ binary_op->type != OperatorType::kTensorFlowGreaterEqual) {
+ return false;
+ }
+ CHECK_EQ(binary_op->inputs.size(), 2);
+
+ const auto& input0_array = model->GetArray(binary_op->inputs[0]);
+ const auto& input1_array = model->GetArray(binary_op->inputs[1]);
+ // Check if both inputs are constant parameters.
+ if (!input0_array.buffer || !input1_array.buffer) {
+ return false;
+ }
+
+ auto& output_array = *model->arrays[binary_op->outputs[0]];
+ // Yield until the output array dims have been resolved.
+ if (!output_array.has_shape()) {
+ return false;
+ }
+
+ // At the moment we don't want to care about fused activation functions.
+ // The idea is that we should do the present constants-propagation before
+ // activation functions get fused.
+ if (binary_op->fused_activation_function !=
+ FusedActivationFunctionType::kNone) {
+ AddMessageF(
+ "Not resolving constant %s because it has a fused activation function",
+ LogName(*binary_op));
+ return false;
+ }
+
+ // Check that input data types agree.
+ CHECK(input0_array.data_type == input1_array.data_type);
+
+ // Do the actual constants propagation
+ EvaluateBinaryOperatorOnConstantInputs(model, binary_op);
+
+ // Remove the binary operator and its inputs
+ if (CountOpsWithInput(*model, binary_op->inputs[0]) == 1) {
+ model->arrays.erase(binary_op->inputs[0]);
+ }
+ if (CountOpsWithInput(*model, binary_op->inputs[1]) == 1) {
+ model->arrays.erase(binary_op->inputs[1]);
+ }
+ AddMessageF("Resolved constant %s to the equivalent constant array",
+ LogName(*binary_op));
+ model->operators.erase(binary_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
new file mode 100644
index 0000000000..0983c43849
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
@@ -0,0 +1,196 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+// Copies data from multiple source arrays to a destination array based on a
+// concatenation dimension. From each array in input_arrays, it copies chunk
+// sizes provided in array_copy_size vector (per array). It uses the buffer
+// in concatenated_array as destination buffer.
+template <ArrayDataType A, typename T>
+void CopyTensorSegments(const std::vector<Array*>& input_arrays,
+ const std::vector<int>& array_copy_size,
+ const int num_elements_concatenated_array,
+ Array* concatenated_array) {
+ for (Array* input_array : input_arrays) {
+ if (!input_array->buffer) {
+ return;
+ }
+ }
+
+ auto& concatenated_array_buffer =
+ concatenated_array->GetMutableBuffer<A>().data;
+ concatenated_array_buffer.resize(num_elements_concatenated_array);
+
+ // It does not matter which array to use to find the value for the total
+ // number of copy steps.
+ CHECK(!input_arrays.empty());
+ CHECK_NE(array_copy_size[0], 0);
+ const int total_copy_steps =
+ input_arrays[0]->GetBuffer<A>().data.size() / array_copy_size[0];
+
+ // Initialize the source pointers to point to beginning of the array buffers.
+ std::vector<const T*> src_ptr;
+ src_ptr.reserve(input_arrays.size());
+ for (Array* input_array : input_arrays) {
+ src_ptr.push_back(input_array->GetBuffer<A>().data.data());
+ }
+
+ // Copy the data from input_arrays to concatenated_array_buffer.
+ T* dest_ptr = concatenated_array_buffer.data();
+ for (int s = 0; s < total_copy_steps; s++) {
+ for (int i = 0; i < input_arrays.size(); i++) {
+ std::copy(src_ptr[i], src_ptr[i] + array_copy_size[i], dest_ptr);
+ src_ptr[i] += array_copy_size[i];
+ dest_ptr += array_copy_size[i];
+ }
+ }
+}
+
+// Receives a series of input arrays of type Array and an integer showing the
+// axis on which those arrays will be concatenated. It returns the concatenated
+// arrray.
+template <ArrayDataType A>
+void ConcatenateTensorBuffers(const std::vector<Array*>& input_arrays,
+ int concatenation_axis,
+ Array* concatenated_array) {
+ int num_elements_concatenated_array = 1;
+ for (int i = 0; i < concatenated_array->shape().dimensions_count(); i++) {
+ num_elements_concatenated_array *= concatenated_array->shape().dims()[i];
+ }
+ // Prepare the data needed for segmented copy from multiple source arrays to
+ // a destination array based on a oncatenation dimension.
+ std::vector<int> array_copy_size(input_arrays.size());
+ int count = 0;
+ for (Array* input_array : input_arrays) {
+ const Shape array_shape = input_array->shape();
+ array_copy_size[count] = 1;
+ for (int i = concatenation_axis; i < array_shape.dimensions_count(); i++) {
+ array_copy_size[count] *= array_shape.dims()[i];
+ }
+ count++;
+ }
+
+ // Do the actual data copy.
+ CopyTensorSegments<A, DataType<A>>(input_arrays, array_copy_size,
+ num_elements_concatenated_array,
+ concatenated_array);
+}
+
+// Sets the minimum and maximum values for the concatenated array. If it's
+// already set (e.g. because of previous pass in TOCO), it doesn't change it and
+// returns. Otherwise it uses the input arrays min and max values to compute the
+// concatenated array min and max.
+void SetMinMaxForConcatenedArray(const std::vector<Array*>& input_arrays,
+ Array* concatenated_array) {
+ CHECK(concatenated_array->data_type == ArrayDataType::kFloat);
+ // If the minmax is already set, use it
+ if (concatenated_array->minmax) return;
+
+ double concat_min = std::numeric_limits<double>::infinity();
+ double concat_max = -std::numeric_limits<double>::infinity();
+
+ for (Array* input_array : input_arrays) {
+ // If any of the input arrays minmax is not set, return.
+ // TODO(ghodrat): shall we add the logic to compute the minmax?
+ if (!input_array->minmax) return;
+ const MinMax& input_minmax = input_array->GetMinMax();
+ concat_min = std::min(concat_min, input_minmax.min);
+ concat_max = std::max(concat_max, input_minmax.max);
+ }
+ MinMax& minmax = concatenated_array->GetOrCreateMinMax();
+ minmax.min = concat_min;
+ minmax.max = concat_max;
+}
+
+} // namespace
+
+// Resolves the concatenation operator if all its inputs are constant arrays.
+bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {
+ const auto concat_it = model->operators.begin() + op_index;
+ const auto* concat_base_op = concat_it->get();
+ if (concat_base_op->type != OperatorType::kConcatenation) {
+ return false;
+ }
+ const auto* concat_op =
+ static_cast<const ConcatenationOperator*>(concat_base_op);
+
+ for (const string& input_name : concat_op->inputs) {
+ // We only expect constant unquantized arrays as input, otherwise we return.
+ // We also make sure the shapes of the input arrays are known and they are
+ // all discardable.
+ const Operator* input_op = GetOpWithOutput(*model, input_name);
+ if (input_op) return false;
+ if (!IsConstantParameterArray(*model, input_name)) return false;
+ if (!model->GetArray(input_name).has_shape()) return false;
+ if (model->GetArray(input_name).quantization_params) return false;
+ if (!IsDiscardableArray(*model, input_name)) return false;
+ }
+
+ const int concatenation_axis = concat_op->concat_dim;
+
+ CHECK_EQ(concat_op->outputs.size(), 1);
+ string concatenated_array_name = concat_op->outputs[0];
+ Array& concatenated_array = model->GetOrCreateArray(concatenated_array_name);
+ std::vector<Array*> input_arrays;
+ for (const string& input_name : concat_op->inputs) {
+ input_arrays.push_back(&model->GetArray(input_name));
+ }
+
+ switch (concatenated_array.data_type) {
+ case ArrayDataType::kFloat:
+ ConcatenateTensorBuffers<ArrayDataType::kFloat>(
+ input_arrays, concatenation_axis, &concatenated_array);
+ SetMinMaxForConcatenedArray(input_arrays, &concatenated_array);
+ break;
+ case ArrayDataType::kUint8:
+ ConcatenateTensorBuffers<ArrayDataType::kUint8>(
+ input_arrays, concatenation_axis, &concatenated_array);
+ break;
+ case ArrayDataType::kInt32:
+ ConcatenateTensorBuffers<ArrayDataType::kInt32>(
+ input_arrays, concatenation_axis, &concatenated_array);
+ break;
+ case ArrayDataType::kInt64:
+ ConcatenateTensorBuffers<ArrayDataType::kInt64>(
+ input_arrays, concatenation_axis, &concatenated_array);
+ break;
+ default:
+ LOG(FATAL) << "ArrayDataType not supported";
+ }
+
+ // Remove all the resolved arrays.
+ for (const string& input_name : concat_op->inputs) {
+ model->arrays.erase(input_name);
+ }
+
+ // Remove concatenate operator
+ model->operators.erase(concat_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
new file mode 100644
index 0000000000..244adcc4c4
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
@@ -0,0 +1,76 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
+ const auto fakequant_it = model->operators.begin() + op_index;
+ const auto* fakequant_base_op = fakequant_it->get();
+ if (fakequant_base_op->type != OperatorType::kFakeQuant) {
+ return false;
+ }
+
+ const auto* fakequant_op =
+ static_cast<const FakeQuantOperator*>(fakequant_base_op);
+
+ // Yield until the fakequant MinMax has been resolved.
+ if (!fakequant_op->minmax) {
+ return false;
+ }
+
+ // This transformation only applies when the input array is constant.
+ if (!IsConstantParameterArray(*model, fakequant_op->inputs[0])) {
+ return false;
+ }
+
+ const auto& input_array = model->GetArray(fakequant_op->inputs[0]);
+ auto& output_array = model->GetArray(fakequant_op->outputs[0]);
+ CHECK(input_array.data_type == ArrayDataType::kFloat);
+ output_array.data_type = ArrayDataType::kFloat;
+ CHECK(!output_array.buffer);
+ const auto& input_buffer = input_array.GetBuffer<ArrayDataType::kFloat>();
+ auto& output_buffer = output_array.GetMutableBuffer<ArrayDataType::kFloat>();
+ const int size = input_buffer.data.size();
+ output_buffer.data.resize(size);
+ QuantizationParams qparams;
+ GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(
+ model->flags, *fakequant_op->minmax, &qparams);
+ for (int i = 0; i < size; i++) {
+ const double src_val = input_buffer.data[i];
+ const double unclamped_quantized_val =
+ std::round(qparams.zero_point + src_val / qparams.scale);
+ const double quantized_val =
+ std::min(255., std::max(0., unclamped_quantized_val));
+ const double dst_val = qparams.scale * (quantized_val - qparams.zero_point);
+ output_buffer.data[i] = dst_val;
+ }
+ if (CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) {
+ model->arrays.erase(fakequant_op->inputs[0]);
+ }
+ model->operators.erase(fakequant_it);
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc
new file mode 100644
index 0000000000..8cc6db1619
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc
@@ -0,0 +1,62 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstddef>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveConstantTensorFlowShape::Run(Model* model, std::size_t op_index) {
+ const auto tfshape_it = model->operators.begin() + op_index;
+ const auto* tfshape_base_op = tfshape_it->get();
+ if (tfshape_base_op->type != OperatorType::kTensorFlowShape) {
+ return false;
+ }
+
+ const auto* tfshape_op =
+ static_cast<const TensorFlowShapeOperator*>(tfshape_base_op);
+
+ const auto& input_array = model->GetArray(tfshape_op->inputs[0]);
+ auto& output_array = model->GetArray(tfshape_op->outputs[0]);
+
+ // Yield until the input array's shape has been resolved.
+ if (!input_array.has_shape()) {
+ return false;
+ }
+
+ // Create a buffer for the output array, making it a constant array, and
+ // copy the input shape into the output buffer.
+ CHECK(!output_array.buffer);
+ auto& output_buffer = output_array.GetMutableBuffer<ArrayDataType::kInt32>();
+ output_buffer.data = input_array.shape().dims();
+
+ // Erase the input array if no longer used
+ if (IsDiscardableArray(*model, tfshape_op->inputs[0]) &&
+ CountOpsWithInput(*model, tfshape_op->inputs[0]) == 1) {
+ model->arrays.erase(tfshape_op->inputs[0]);
+ }
+ model->operators.erase(tfshape_it);
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
new file mode 100644
index 0000000000..bb9bda3c82
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
@@ -0,0 +1,175 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include <algorithm>
+#include <cmath>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
+ const auto unary_it = model->operators.begin() + op_index;
+ const auto* unary_op = unary_it->get();
+ // Test for unary ops of types that we know how to resolve
+ if (unary_op->type != OperatorType::kTensorFlowRsqrt &&
+ unary_op->type != OperatorType::kTensorFlowSqrt &&
+ unary_op->type != OperatorType::kTensorFlowSquare &&
+ unary_op->type != OperatorType::kTensorFlowSum &&
+ unary_op->type != OperatorType::kTensorFlowMin &&
+ unary_op->type != OperatorType::kTensorFlowMax &&
+ unary_op->type != OperatorType::kTensorFlowReshape) {
+ return false;
+ }
+ // Check if the input is a constant parameter.
+ if (!IsConstantParameterArray(*model, unary_op->inputs[0])) {
+ return false;
+ }
+
+ // if the unary op involves a tensor required by a rnn state, ignore it
+ for (const auto& rnn_state : model->flags.rnn_states()) {
+ if (unary_op->inputs[0] == rnn_state.back_edge_source_array()) {
+ return false;
+ }
+ if (unary_op->inputs[0] == rnn_state.state_array()) {
+ return false;
+ }
+ }
+
+ // At the moment we don't want to care about fused activation functions.
+ // The idea is that we should do the present constants-propagation before
+ // activation functions get fused.
+ if (unary_op->fused_activation_function !=
+ FusedActivationFunctionType::kNone) {
+ AddMessageF(
+ "Not resolving constant %s "
+ " because it has a fused activation function",
+ LogName(*unary_op));
+ return false;
+ }
+ const auto& input_array = model->GetArray(unary_op->inputs[0]);
+ // We have already tested above for existence of buffers (synonymous to being
+ // a constant param).
+ CHECK(input_array.buffer);
+ // At the moment we only support float buffers.
+ if (input_array.buffer->type != ArrayDataType::kFloat) {
+ return false;
+ }
+ const auto& input_float_data =
+ input_array.GetBuffer<ArrayDataType::kFloat>().data;
+ // Create the float buffer on the output array, effectively turning it into
+ // a constant parameter
+ const auto& output_name = unary_op->outputs[0];
+ auto& output_array = model->GetArray(output_name);
+ // Yield until the output array dims have been resolved.
+ if (!output_array.has_shape()) {
+ return false;
+ }
+
+ int input_buffer_size = RequiredBufferSizeForShape(input_array.shape());
+ int output_buffer_size = RequiredBufferSizeForShape(output_array.shape());
+ const Shape& input_shape = input_array.shape();
+ const Shape& output_shape = output_array.shape();
+
+ auto& output_float_data =
+ output_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
+ output_float_data.resize(output_buffer_size);
+
+ const int output_dims_count = output_shape.dimensions_count();
+ if (unary_op->type == OperatorType::kTensorFlowReshape) {
+ CHECK(input_buffer_size == output_buffer_size);
+ memcpy(output_float_data.data(), input_float_data.data(),
+ input_buffer_size * sizeof(input_float_data[0]));
+ } else if (unary_op->type == OperatorType::kTensorFlowSum) {
+ // At the moment only full reduction across all dimensions is supported.
+ for (int i = 0; i < output_dims_count; i++) {
+ CHECK_EQ(output_shape.dims(i), 1);
+ }
+ float sum = 0.f;
+ const int input_size = RequiredBufferSizeForShape(input_shape);
+ for (int i = 0; i < input_size; i++) {
+ sum += input_float_data[i];
+ }
+ output_float_data[0] = sum;
+ } else if (unary_op->type == OperatorType::kTensorFlowMin) {
+ // At the moment only full reduction across all dimensions is supported.
+ // TODO(starka): Output should not be padded.
+ for (int i = 0; i < output_dims_count; i++) {
+ CHECK_EQ(output_shape.dims(i), 1);
+ }
+ float min = input_float_data[0];
+ const int input_size = RequiredBufferSizeForShape(input_shape);
+ for (int i = 0; i < input_size; i++) {
+ min = std::min(min, input_float_data[i]);
+ }
+ output_float_data[0] = min;
+ } else if (unary_op->type == OperatorType::kTensorFlowMax) {
+ // At the moment only full reduction across all dimensions is supported.
+ // TODO(starka): Output should not be padded.
+ for (int i = 0; i < output_dims_count; i++) {
+ CHECK_EQ(output_shape.dims(i), 1);
+ }
+ float max = input_float_data[0];
+ const int input_size = RequiredBufferSizeForShape(input_shape);
+ for (int i = 0; i < input_size; i++) {
+ max = std::max(max, input_float_data[i]);
+ }
+ output_float_data[0] = max;
+ } else if (unary_op->type == OperatorType::kTensorFlowRsqrt ||
+ unary_op->type == OperatorType::kTensorFlowSqrt ||
+ unary_op->type == OperatorType::kTensorFlowSquare) {
+ // Element-wise ops. Should have perfectly matching sizes here.
+ const int input_size = RequiredBufferSizeForShape(input_shape);
+ for (int i = 0; i < output_dims_count; i++) {
+ CHECK_EQ(output_shape.dims(i), input_shape.dims(i));
+ }
+
+ for (int i = 0; i < input_size; i++) {
+ const float val = input_float_data[i];
+ float outval = 0.f;
+ if (unary_op->type == OperatorType::kTensorFlowRsqrt) {
+ outval = 1.0f / std::sqrt(val);
+ } else if (unary_op->type == OperatorType::kTensorFlowSqrt) {
+ outval = std::sqrt(val);
+ } else if (unary_op->type == OperatorType::kTensorFlowSquare) {
+ outval = val * val;
+ } else {
+ LOG(FATAL) << "should not get here.";
+ }
+ output_float_data[i] = outval;
+ }
+ } else {
+ LOG(FATAL) << "should not get here.";
+ }
+ for (const auto& input : unary_op->inputs) {
+ if (CountOpsWithInput(*model, input) == 1) {
+ model->arrays.erase(input);
+ }
+ }
+ AddMessageF("Resolved constant %s to the equivalent constant array",
+ LogName(*unary_op));
+ model->operators.erase(unary_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc
new file mode 100644
index 0000000000..d25c773f19
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc
@@ -0,0 +1,51 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveMeanAttributes::Run(Model* model, std::size_t op_index) {
+ auto* mean_op = model->operators[op_index].get();
+ if (mean_op->type != OperatorType::kMean) return false;
+ auto* op = static_cast<MeanOperator*>(mean_op);
+
+ if (!op->reduction_indices.empty()) return false;
+ if (op->inputs.size() != 2) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
+
+ const auto& indices_array = *model->arrays[op->inputs[1]];
+ if (!indices_array.has_shape()) return false;
+
+ op->reduction_indices = indices_array.GetBuffer<ArrayDataType::kInt32>().data;
+
+ // At the moment, we only support simultaneous reduction over width and
+ // height. This is mainly limited by the fact that currently, the runtime
+ // arrays are always 4-dimensional.
+ CHECK_EQ(op->reduction_indices.size(), 2);
+ CHECK((op->reduction_indices[0] == 1 && op->reduction_indices[1] == 2) ||
+ (op->reduction_indices[0] == 2 && op->reduction_indices[1] == 1));
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc
new file mode 100644
index 0000000000..d5f5869c62
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc
@@ -0,0 +1,55 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolvePadAttributes::Run(Model* model, std::size_t op_index) {
+ const auto pad_it = model->operators.begin() + op_index;
+ auto* pad_op = pad_it->get();
+ if (pad_op->type != OperatorType::kPad) return false;
+
+ auto* op = static_cast<PadOperator*>(pad_op);
+ if (!op->left_padding.empty()) return false;
+
+ CHECK_EQ(op->inputs.size(), 2);
+ if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
+
+ const auto& array = *model->arrays[op->inputs[1]];
+ if (!array.has_shape()) return false;
+
+ const std::vector<int>& dims = array.shape().dims();
+ CHECK_EQ(dims.size(), 2);
+
+ std::vector<int> buffer = array.GetBuffer<ArrayDataType::kInt32>().data;
+
+ for (int i = 0; i < dims[0]; ++i) {
+ op->left_padding.push_back(buffer[i * 2]);
+ op->right_padding.push_back(buffer[i * 2 + 1]);
+ }
+
+ // TODO(dkalenichenko): Delete the extra input?
+
+ return true;
+}
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
new file mode 100644
index 0000000000..8fa7b83bed
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
@@ -0,0 +1,93 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
+ auto reorder_it = model->operators.begin() + op_index;
+ auto* reorder_op = static_cast<ReorderAxesOperator*>(reorder_it->get());
+ if (reorder_op->type != OperatorType::kReorderAxes) {
+ return false;
+ }
+ const auto& input_array_name = reorder_op->inputs[0];
+ const auto& output_array_name = reorder_op->outputs[0];
+ auto& input_array = model->GetArray(input_array_name);
+ auto& output_array = model->GetArray(output_array_name);
+ string constant_input_array_name = input_array_name;
+ if (!input_array.buffer) {
+ const auto* op_producing_input = GetOpWithOutput(*model, input_array_name);
+ if (op_producing_input &&
+ op_producing_input->type == OperatorType::kFakeQuant) {
+ constant_input_array_name = op_producing_input->inputs[0];
+ }
+ }
+ auto& constant_input_array = model->GetArray(constant_input_array_name);
+ if (!constant_input_array.buffer) {
+ return false;
+ }
+ // Yield until output dims have been resolved.
+ if (!output_array.has_shape()) {
+ return false;
+ }
+ // Reorder the input array dims and buffer data
+ CHECK(constant_input_array.buffer->type == ArrayDataType::kFloat);
+ CHECK(!output_array.buffer);
+ auto& input_data =
+ constant_input_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
+ std::vector<float> reordered_data;
+ reordered_data.resize(RequiredBufferSizeForShape(output_array.shape()));
+ const auto input_axes_order = reorder_op->input_axes_order;
+ const auto output_axes_order = reorder_op->output_axes_order;
+ // TODO(b/62904716) Shapes should be used directly.
+ Shape input_shape = constant_input_array.shape();
+ Shape output_shape = output_array.shape();
+ if (AxesCount(input_axes_order) == 2) {
+ UnextendShape(&input_shape, 2);
+ UnextendShape(&output_shape, 2);
+ }
+ ShuffleArray(input_shape, input_axes_order, output_axes_order, output_shape,
+ input_data.data(), reordered_data.data());
+ input_data = reordered_data;
+ input_array.copy_shape(output_array.shape());
+ constant_input_array.copy_shape(output_array.shape());
+
+ // Update the edges of the graph to point to the input array
+ for (const auto& other_op : model->operators) {
+ for (auto& input : other_op->inputs) {
+ if (input == output_array_name) {
+ input = input_array_name;
+ }
+ }
+ }
+
+ AddMessageF("Reordered axes for array %s", input_array_name);
+
+ // Remove the op and output array.
+ model->arrays.erase(output_array_name);
+ model->operators.erase(reorder_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc
new file mode 100644
index 0000000000..bed2a85bd2
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc
@@ -0,0 +1,49 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <iterator>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveReshapeAttributes::Run(Model* model, std::size_t op_index) {
+ const auto reshape_it = model->operators.begin() + op_index;
+ auto* reshape_op = reshape_it->get();
+ if (reshape_op->type != OperatorType::kTensorFlowReshape) {
+ return false;
+ }
+
+ auto* op = static_cast<TensorFlowReshapeOperator*>(reshape_op);
+
+ if (!op->shape.empty()) return false;
+
+ if (IsConstantParameterArray(*model, reshape_op->inputs[1])) {
+ const auto& constant_input_array = *model->arrays[reshape_op->inputs[1]];
+ op->shape = constant_input_array.GetBuffer<ArrayDataType::kInt32>().data;
+ }
+
+ if (op->shape.empty()) return false;
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc
new file mode 100644
index 0000000000..1d0a2ec8f6
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc
@@ -0,0 +1,52 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveSliceAttributes::Run(Model* model, std::size_t op_index) {
+ const auto slice_it = model->operators.begin() + op_index;
+ auto* slice_op = slice_it->get();
+ if (slice_op->type != OperatorType::kSlice) return false;
+
+ auto* op = static_cast<SliceOperator*>(slice_op);
+ if (!op->begin.empty()) return false;
+
+ CHECK_EQ(op->inputs.size(), 3);
+ if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
+
+ const auto& begin_array = *model->arrays[op->inputs[1]];
+ if (!begin_array.has_shape()) return false;
+
+ const auto& size_array = *model->arrays[op->inputs[2]];
+ if (!size_array.has_shape()) return false;
+
+ op->begin = begin_array.GetBuffer<ArrayDataType::kInt32>().data;
+ op->size = size_array.GetBuffer<ArrayDataType::kInt32>().data;
+
+ // TODO(dkalenichenko): Delete the extra inputs?
+
+ return true;
+}
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
new file mode 100644
index 0000000000..5fc3b25bc1
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
@@ -0,0 +1,62 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) {
+ const auto slice_it = model->operators.begin() + op_index;
+ auto* slice_op = slice_it->get();
+ if (slice_op->type != OperatorType::kStridedSlice) return false;
+
+ auto* op = static_cast<StridedSliceOperator*>(slice_op);
+ if (!op->start_indices.empty()) return false;
+
+ CHECK_EQ(op->inputs.size(), 4);
+ if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[3])) return false;
+
+ const auto& start_array = *model->arrays[op->inputs[1]];
+ if (!start_array.has_shape()) return false;
+
+ const auto& stop_array = *model->arrays[op->inputs[2]];
+ if (!stop_array.has_shape()) return false;
+
+ const auto& stride_array = *model->arrays[op->inputs[3]];
+ if (!stride_array.has_shape()) return false;
+
+ op->start_indices = start_array.GetBuffer<ArrayDataType::kInt32>().data;
+ op->stop_indices = stop_array.GetBuffer<ArrayDataType::kInt32>().data;
+ op->strides = stride_array.GetBuffer<ArrayDataType::kInt32>().data;
+
+ // Only 4D arrays are supported for now.
+ CHECK_EQ(op->start_indices.size(), 4);
+ CHECK_EQ(op->stop_indices.size(), 4);
+ CHECK_EQ(op->strides.size(), 4);
+
+ // TODO(dkalenichenko): Delete the extra inputs?
+
+ return true;
+}
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc
new file mode 100644
index 0000000000..b482f5cf51
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc
@@ -0,0 +1,86 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) {
+ auto concat_it = model->operators.begin() + op_index;
+ const auto* tf_concat_op = concat_it->get();
+ if (tf_concat_op->type != OperatorType::kTensorFlowConcat &&
+ tf_concat_op->type != OperatorType::kTensorFlowConcatV2) {
+ return false;
+ }
+
+ CHECK_GE(tf_concat_op->inputs.size(), 2);
+ // TensorFlow Concat and ConcatV2 nodes only differ by the ordering
+ // of inputs: in Concat, the concat_dim is the first input, while in
+ // ConcatV2, it is the last input.
+ std::size_t concat_dim_pos = 0;
+ if (tf_concat_op->type == OperatorType::kTensorFlowConcatV2) {
+ concat_dim_pos = tf_concat_op->inputs.size() - 1;
+ }
+ const string concat_dim_name = tf_concat_op->inputs[concat_dim_pos];
+ std::vector<string> concat_input_names;
+ for (std::size_t i = 0; i < tf_concat_op->inputs.size(); i++) {
+ if (i != concat_dim_pos) {
+ concat_input_names.push_back(tf_concat_op->inputs[i]);
+ }
+ }
+ // If the concat_dim array hasn't been resolved to a constant yet,
+ // we need to yield.
+ const auto& concat_dim_array = model->GetArray(concat_dim_name);
+ if (!concat_dim_array.buffer) {
+ AddMessageF("Waiting for the concat_dim of %s to be resolved to a constant",
+ LogName(*tf_concat_op));
+ return false;
+ }
+
+ CHECK(concat_dim_array.data_type == ArrayDataType::kInt32);
+ const auto& concat_dim_data =
+ concat_dim_array.GetBuffer<ArrayDataType::kInt32>().data;
+ CHECK_EQ(concat_dim_data.size(), 1);
+ const int concat_dim = concat_dim_data[0];
+
+ // Create the Concatenation op replacing the TensorFlowConcat op.
+ auto* concatenation_op = new ConcatenationOperator;
+ concatenation_op->concat_dim = concat_dim;
+ concatenation_op->inputs = concat_input_names;
+ concatenation_op->outputs = {tf_concat_op->outputs[0]};
+ auto depth_concat_it = model->operators.emplace(concat_it, concatenation_op);
+ CHECK_EQ(depth_concat_it->get(), concatenation_op);
+ // Update invalidated iterator
+ concat_it = depth_concat_it + 1;
+ CHECK_EQ(concat_it->get(), tf_concat_op);
+
+ // Remove the concat_dim array if it is not used by anything else.
+ if (CountOpsWithInput(*model, concat_dim_name) == 1) {
+ model->arrays.erase(concat_dim_name);
+ }
+ // Remove the TensorFlowConcat op
+ model->operators.erase(concat_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
new file mode 100644
index 0000000000..bea7487051
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
@@ -0,0 +1,106 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
+ auto matmul_it = model->operators.begin() + op_index;
+ if (matmul_it->get()->type != OperatorType::kTensorFlowMatMul) {
+ return false;
+ }
+ const auto* matmul_op = matmul_it->get();
+
+ // Find the op producing the array passed to this MatMul
+ auto previous_op_it = model->operators.begin();
+ bool found = false;
+ for (; previous_op_it != model->operators.end(); ++previous_op_it) {
+ for (const auto& output : (*previous_op_it)->outputs) {
+ if (output == matmul_op->inputs[0]) {
+ found = true;
+ break;
+ }
+ }
+ if (found) {
+ break;
+ }
+ }
+ Operator* previous_op = (found) ? previous_op_it->get() : nullptr;
+
+ // construct the new FullyConnectedOperator
+ auto* fc_op = new FullyConnectedOperator;
+ fc_op->outputs = matmul_op->outputs;
+
+ // insert the newly constructed FullyConnectedOperator
+ auto fc_it = model->operators.emplace(matmul_it, fc_op);
+
+ // refresh invalidated iterator
+ matmul_it = fc_it + 1;
+ DCHECK_EQ(matmul_it->get(), matmul_op);
+
+ // The way that TensorFlow encodes FullyConnected ops is as a pair
+ // (Reshape, MatMul), so we want to remove the Reshape op and rewrite the
+ // MatMul
+ // op as a FullyConnected. However, TensorFlow skips the Reshape ops if the
+ // input doesn't need reshaping, so we can't just match (Reshape, MatMul)
+ // pairs.
+ if (previous_op && previous_op->type == OperatorType::kTensorFlowReshape) {
+ AddMessageF("Combining %s and %s into %s", LogName(*previous_op),
+ LogName(*matmul_op), LogName(*fc_op));
+ const auto& previous_op_output = previous_op->outputs[0];
+ if (CountOpsWithInput(*model, previous_op_output) == 1) {
+ model->arrays.erase(previous_op_output);
+ }
+ CHECK_EQ(previous_op->inputs.size(), 2);
+ fc_op->inputs = {previous_op->inputs[0], matmul_op->inputs[1]};
+ // Only remove Reshape node if no other node uses its output.
+ if (CountOpsWithInput(*model, previous_op_output) == 1) {
+ const auto& previous_op_shape = previous_op->inputs[1];
+ if (CountOpsWithInput(*model, previous_op_shape) == 1 &&
+ !GetOpWithOutput(*model, previous_op_shape)) {
+ model->arrays.erase(previous_op_shape);
+ }
+ model->operators.erase(previous_op_it);
+ }
+
+ // We may have just invalidated matmul_it, so let's refresh it now.
+ matmul_it = model->operators.begin();
+ for (; matmul_it != model->operators.end(); ++matmul_it) {
+ if (matmul_it->get() == matmul_op) {
+ break;
+ }
+ }
+ CHECK(matmul_it != model->operators.end());
+ CHECK(matmul_it->get() == matmul_op);
+ } else {
+ AddMessageF("Replacing %s by a FullyConnected operator",
+ LogName(*matmul_op));
+ fc_op->inputs = {matmul_op->inputs[0], matmul_op->inputs[1]};
+ }
+
+ // erase the MatMul operator
+ model->operators.erase(matmul_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc
new file mode 100644
index 0000000000..cfa5ce0716
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) {
+ const auto merge_it = model->operators.begin() + op_index;
+ const auto* merge_op = merge_it->get();
+ if (merge_op->type != OperatorType::kTensorFlowMerge) {
+ return false;
+ }
+
+ // We need to yield until this Merge node has only 1 input, which will mean
+ // that that is the selected input. Other graph transformations on other nodes
+ // such as ResolveTensorFlowSwitch, will take care of trimming the
+ // non-selected inputs, so that at some point there will be only 1 input left.
+ if (merge_op->inputs.size() > 1) {
+ AddMessageF("Waiting for %s to be resolved", LogName(*merge_op));
+ return false;
+ }
+
+ // Now that the merge node has 1 input exactly, it is the same as an Identity
+ // node and can be resolved trivially.
+ CHECK_EQ(merge_op->inputs.size(), 1);
+
+ // Update the edges of the graph ahead of removing the node.
+ for (const auto& other_op : model->operators) {
+ for (auto& input : other_op->inputs) {
+ if (input == merge_op->outputs[0]) {
+ input = merge_op->inputs[0];
+ }
+ }
+ }
+
+ // Remove the node and its output array.
+ AddMessageF("Removing already-resolved %s", LogName(*merge_op));
+ model->arrays.erase(merge_op->outputs[0]);
+ model->operators.erase(merge_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc
new file mode 100644
index 0000000000..1d3f42b5ec
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc
@@ -0,0 +1,54 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveTensorFlowSqueeze::Run(Model* model, std::size_t op_index) {
+ const auto squeeze_it = model->operators.begin() + op_index;
+ const auto* squeeze_op = squeeze_it->get();
+ if (squeeze_op->type != OperatorType::kSqueeze) {
+ return false;
+ }
+
+ CHECK_EQ(squeeze_op->inputs.size(), 1);
+ CHECK_EQ(squeeze_op->outputs.size(), 1);
+
+ // If the output is consumed by a reshape op, it's a trivial squeeze.
+ if (CountOpsWithInput(*model, squeeze_op->outputs[0]) == 1) {
+ const auto* next_op = GetOpWithInput(*model, squeeze_op->outputs[0]);
+ if (next_op->type == OperatorType::kTensorFlowReshape) {
+ AddMessageF(
+ "%s is trivial because its output is only consumed by a "
+ "Reshape op",
+ LogName(*squeeze_op));
+
+ return RemoveTrivialPassthroughOp(this, model, op_index);
+ }
+ }
+
+ return false;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc
new file mode 100644
index 0000000000..55adfca037
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc
@@ -0,0 +1,123 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) {
+ const auto switch_it = model->operators.begin() + op_index;
+ const auto* switch_op = switch_it->get();
+ if (switch_op->type != OperatorType::kTensorFlowSwitch) {
+ return false;
+ }
+
+ CHECK_EQ(switch_op->inputs.size(), 2);
+ CHECK_EQ(switch_op->outputs.size(), 2);
+ const string& predicate_name = switch_op->inputs[1];
+ // If the predicate array hasn't been resolved to a constant yet,
+ // we need to yield.
+ if (!IsConstantParameterArray(*model, predicate_name)) {
+ AddMessageF(
+ "Waiting for the boolean predicate of %s to be resolved to a constant",
+ LogName(*switch_op));
+ return false;
+ }
+
+ // The predicate should be boolean, and should consist of a single value.
+ const auto& predicate_array = model->GetArray(predicate_name);
+ CHECK(predicate_array.data_type == ArrayDataType::kBool);
+ for (const auto& dim : predicate_array.shape().dims()) {
+ CHECK_EQ(dim, 1);
+ }
+
+ // Obtain the predicate boolean value.
+ const auto& predicate_data =
+ predicate_array.GetBuffer<ArrayDataType::kBool>().data;
+ CHECK_EQ(predicate_data.size(), 1);
+ const bool predicate_value = predicate_data[0];
+
+ // From the TensorFlow docs on .switch() in
+ // third_party/tensorflow/python/ops/control_flow_ops.py
+ //
+ // If `pred` is false, the `data` input is forwared to the first output.
+ // Otherwise, the data goes to the second output.
+ //
+ // Note that this comment used to say the opposite and was recently fixed:
+ // https://github.com/tensorflow/tensorflow/commit/bc456e361d49d1d89a74b80060c70efb51fd7d87#diff-76ab9dafbe12c20ddc3769c6b108986c
+ const int selected_output_index = predicate_value ? 1 : 0;
+ const int nonselected_output_index = predicate_value ? 0 : 1;
+
+ // Update the edges of the graph ahead of removing the node:
+ // edges that were pointing to the selected output, should instead
+ // point to the input of the Switch node.
+ for (const auto& other_op : model->operators) {
+ for (auto& input : other_op->inputs) {
+ if (input == switch_op->outputs[selected_output_index]) {
+ input = switch_op->inputs[0];
+ }
+ }
+ }
+
+ // There remains to handle the edges that were pointing to the nonselected
+ // output. We will just discard those edges. Concretely, at the moment,
+ // our only examples of graphs with Switch nodes have them feeding into Merge
+ // nodes, so what we're saying here is that we'll make the convention,
+ // in our toco internal representation, that Merge nodes with only 1 input
+ // are Merge nodes that have been resolved already and should be have as
+ // Identity nodes, simply forwarding their input.
+ //
+ for (const auto& other_op : model->operators) {
+ auto input_it = other_op->inputs.begin();
+ while (input_it != other_op->inputs.end()) {
+ if (*input_it == switch_op->outputs[nonselected_output_index]) {
+ // Let us guard our assumption that only Merge nodes consume the outputs
+ // of Switch nodes:
+ CHECK(other_op->type == OperatorType::kTensorFlowMerge);
+ input_it = other_op->inputs.erase(input_it);
+ } else {
+ ++input_it;
+ }
+ }
+ }
+
+ // Remove the output arrays if they are now unused.
+ for (int i = 0; i < 2; i++) {
+ if (!GetOpWithInput(*model, switch_op->outputs[i])) {
+ model->arrays.erase(switch_op->outputs[i]);
+ }
+ }
+ // Remove input arrays if they are only used by the switch itself and aren't
+ // the output of another op (will get handled by RemoveUnusedOp in that case).
+ for (const auto& input : switch_op->inputs) {
+ if (CountOpsWithInput(*model, input) == 1 &&
+ !GetOpWithOutput(*model, input)) {
+ model->arrays.erase(input);
+ }
+ }
+ // Remove the switch node itself.
+ AddMessageF("Removing already-resolved %s", LogName(*switch_op));
+ model->operators.erase(switch_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc
new file mode 100644
index 0000000000..9f7e7c42a2
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc
@@ -0,0 +1,97 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+void RemoveTileOperator(Model* model, Operator* tile_op, Operator* binary_op,
+ int operand_index) {
+ CHECK(tile_op->type == OperatorType::kTensorFlowTile);
+ CHECK_EQ(binary_op->inputs.size(), 2);
+ CHECK_EQ(tile_op->inputs.size(), 2);
+ const string tile_multiplier_array = tile_op->inputs[1];
+ const string tile_output_array = tile_op->outputs[0];
+ binary_op->inputs[operand_index] = tile_op->inputs[0];
+ auto tile_it = model->operators.begin();
+ for (; tile_it != model->operators.end(); ++tile_it) {
+ if (tile_it->get() == tile_op) {
+ break;
+ }
+ }
+ CHECK(tile_it != model->operators.end());
+ CHECK(tile_it->get() == tile_op);
+ model->operators.erase(tile_it);
+ if (!CountOpsWithInput(*model, tile_multiplier_array) &&
+ !GetOpWithOutput(*model, tile_multiplier_array)) {
+ model->arrays.erase(tile_multiplier_array);
+ }
+ if (!CountOpsWithInput(*model, tile_output_array)) {
+ model->arrays.erase(tile_output_array);
+ }
+}
+} // namespace
+
+bool ResolveTensorFlowTile::Run(Model* model, std::size_t op_index) {
+ const auto binary_it = model->operators.begin() + op_index;
+ auto* binary_op = binary_it->get();
+ // Test for binary ops of types that we know how to resolve
+ if (binary_op->inputs.size() != 2) {
+ return false;
+ }
+ if (binary_op->type != OperatorType::kAdd &&
+ binary_op->type != OperatorType::kMul &&
+ binary_op->type != OperatorType::kSub &&
+ binary_op->type != OperatorType::kDiv) {
+ return false;
+ }
+
+ Operator* const op[2] = {
+ GetOpWithOutput(*model, binary_op->inputs[0]),
+ GetOpWithOutput(*model, binary_op->inputs[1]),
+ };
+
+ // In the unlikely case where both operands are Tile, we can't infer the
+ // output
+ // size without the Tile nodes, so we have to bail out.
+ if (op[0] && op[0]->type == OperatorType::kTensorFlowTile && op[1] &&
+ op[1]->type == OperatorType::kTensorFlowTile) {
+ return false;
+ }
+
+ for (int i = 0; i < 2; i++) {
+ if (op[i] && op[i]->type == OperatorType::kTensorFlowTile) {
+ // We can only remove a Tile operator is no other op than the present
+ // binary op was consuming its tiled output.
+ if (CountOpsWithInput(*model, binary_op->inputs[i]) == 1) {
+ AddMessageF("Removing %s", LogName(*op[i]));
+ RemoveTileOperator(model, op[i], binary_op, i);
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
new file mode 100644
index 0000000000..8931498782
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
@@ -0,0 +1,31 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+)
+
+tf_cc_test(
+ name = "resolve_constant_concatenation_test",
+ srcs = ["resolve_constant_concatenation_test.cc"],
+ deps = [
+ "//tensorflow/contrib/lite/toco:graph_transformations",
+ "//tensorflow/contrib/lite/toco:model",
+ "//tensorflow/contrib/lite/toco:tooling_util",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc
new file mode 100644
index 0000000000..c6705ad305
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc
@@ -0,0 +1,221 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+//#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+namespace toco {
+
+namespace {
+// A gmock matcher that check that elements of a float vector match to a given
+// tolerance.
+std::vector<testing::Matcher<float>> ArrayFloatNear(
+ const std::vector<float>& values, float max_abs_error = 1e-5) {
+ std::vector<testing::Matcher<float>> matchers;
+ matchers.reserve(values.size());
+ for (const float& v : values) {
+ matchers.emplace_back(testing::FloatNear(v, max_abs_error));
+ }
+ return matchers;
+}
+} // namespace
+
+// The following 3 tests make sure the concatenation operation on different axis
+// values match TensorFlow results listed below:
+//
+// x0 = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
+// x1 = [[[10, 11], [12, 13]], [[14, 15], [16, 17]]]
+// x2 = [[[20, 21], [22, 23]], [[24, 25], [26, 27]]]
+// x3 = [[[30, 31], [32, 33]], [[34, 35], [36, 37]]]
+//
+// ConcatAtAxis0 test:
+// t0 = tf.concat([x0, x1, x2, x3], 0)
+// [[[ 0 1]
+// [ 2 3]]
+//
+// [[ 4 5]
+// [ 6 7]]
+//
+// [[10 11]
+// [12 13]]
+//
+// [[14 15]
+// [16 17]]
+//
+// [[20 21]
+// [22 23]]
+//
+// [[24 25]
+// [26 27]]
+//
+// [[30 31]
+// [32 33]]
+//
+// [[34 35]
+// [36 37]]]
+//
+// ConcatAtAxis1 test:
+// t1 = tf.concat([x0, x1, x2, x3], 1)
+// [[[ 0 1]
+// [ 2 3]
+// [10 11]
+// [12 13]
+// [20 21]
+// [22 23]
+// [30 31]
+// [32 33]]
+//
+// [[ 4 5]
+// [ 6 7]
+// [14 15]
+// [16 17]
+// [24 25]
+// [26 27]
+// [34 35]
+// [36 37]]]
+//
+// ConcatAtAxis2 test:
+// t2 = tf.concat([x0, x1, x2, x3], 2)
+// [[[ 0 1 10 11 20 21 30 31]
+// [ 2 3 12 13 22 23 32 33]]
+//
+// [[ 4 5 14 15 24 25 34 35]
+// [ 6 7 16 17 26 27 36 37]]]
+
+class ResolveConstantConcatenationTest : public ::testing::Test {
+ protected:
+ ResolveConstantConcatenationTest() {}
+
+ // Prepare a hypothetical TOCO model with one Concatenation operator in it
+ // together with 4 arrays as its inputs.
+ // It receives the dimension of concatenation as input.
+ void PrepareModel(Model* model, int concat_dim) {
+ std::vector<string> concat_input_names = {"array0", "array1", "array2",
+ "array3"};
+
+ const int kDim = 3;
+ const int kElementPerDim = 2;
+ const int kBufSize = 8;
+ const int kNumArrays = 4;
+ static float in_buf[kNumArrays][kBufSize] = {
+ {0., 1., 2., 3., 4., 5., 6., 7.},
+ {10., 11., 12., 13., 14., 15., 16., 17.},
+ {20., 21., 22., 23., 24., 25., 26., 27.},
+ {30., 31., 32., 33., 34., 35., 36., 37.}};
+ int cnt = 0;
+ for (const string& concat_input_name : concat_input_names) {
+ Array& in_array = model->GetOrCreateArray(concat_input_name);
+ in_array.data_type = ArrayDataType::kFloat;
+
+ // Initialize shape for the input array.
+ Shape* in_array_shape = in_array.mutable_shape();
+ std::vector<int>* in_array_shape_dim = in_array_shape->mutable_dims();
+ for (int i = 0; i < kDim; i++) {
+ in_array_shape_dim->push_back(kElementPerDim);
+ }
+ auto& in_array_buffer =
+ in_array.GetMutableBuffer<toco::ArrayDataType::kFloat>();
+ in_array_buffer.data.resize(kBufSize);
+ float* buf_ptr =
+ in_array.GetMutableBuffer<toco::ArrayDataType::kFloat>().data.data();
+ std::copy(in_buf[cnt], in_buf[cnt] + kBufSize, buf_ptr);
+ cnt++;
+ }
+ auto* concatenation_op = new ConcatenationOperator;
+ concatenation_op->concat_dim = concat_dim;
+ concatenation_op->inputs = concat_input_names;
+ concatenation_op->outputs = {"concat_op_outputs"};
+ Array& out_array = model->GetOrCreateArray(concatenation_op->outputs[0]);
+ out_array.data_type = ArrayDataType::kFloat;
+ Shape* out_array_shape = out_array.mutable_shape();
+ std::vector<int>* out_array_shape_dim = out_array_shape->mutable_dims();
+ out_array_shape_dim->resize(kDim);
+ for (int i = 0; i < kDim; i++) {
+ if (i == concat_dim) {
+ (*out_array_shape_dim)[i] = kNumArrays * kElementPerDim;
+ } else {
+ (*out_array_shape_dim)[i] = kElementPerDim;
+ }
+ }
+ model->operators.push_back(std::unique_ptr<Operator>(concatenation_op));
+ }
+};
+
+TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis0) {
+ Model model;
+ const int concat_dim = 0;
+ PrepareModel(&model, concat_dim);
+
+ GraphTransformationsSet graph_transformation_set;
+ graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
+ EXPECT_THAT(model.arrays.size(), 5);
+ (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
+ EXPECT_THAT(model.arrays.size(), 1);
+
+ auto& concatenated_array = (*model.arrays.begin()).second;
+ EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
+ ElementsAreArray(ArrayFloatNear(
+ {0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12.,
+ 13., 14., 15., 16., 17., 20., 21., 22., 23., 24., 25.,
+ 26., 27., 30., 31., 32., 33., 34., 35., 36., 37.})));
+}
+
+TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis1) {
+ Model model;
+ const int concat_dim = 1;
+ PrepareModel(&model, concat_dim);
+
+ GraphTransformationsSet graph_transformation_set;
+ graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
+ EXPECT_THAT(model.arrays.size(), 5);
+ (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
+ EXPECT_THAT(model.arrays.size(), 1);
+
+ auto& concatenated_array = (*model.arrays.begin()).second;
+ EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
+ ElementsAreArray(ArrayFloatNear(
+ {0., 1., 2., 3., 10., 11., 12., 13., 20., 21., 22.,
+ 23., 30., 31., 32., 33., 4., 5., 6., 7., 14., 15.,
+ 16., 17., 24., 25., 26., 27., 34., 35., 36., 37.})));
+}
+
+TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis2) {
+ Model model;
+ const int concat_dim = 2;
+ PrepareModel(&model, concat_dim);
+
+ GraphTransformationsSet graph_transformation_set;
+ graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
+ EXPECT_THAT(model.arrays.size(), 5);
+ (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
+ EXPECT_THAT(model.arrays.size(), 1);
+
+ auto& concatenated_array = (*model.arrays.begin()).second;
+ EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
+ ElementsAreArray(ArrayFloatNear(
+ {0., 1., 10., 11., 20., 21., 30., 31., 2., 3., 12.,
+ 13., 22., 23., 32., 33., 4., 5., 14., 15., 24., 25.,
+ 34., 35., 6., 7., 16., 17., 26., 27., 36., 37.})));
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc
new file mode 100644
index 0000000000..4e273343df
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc
@@ -0,0 +1,73 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) {
+ const auto it = model->operators.begin() + op_index;
+ auto* op = it->get();
+
+ // If a conv operation has an im2col array, yield: it should be dropped first.
+ if ((op->type == OperatorType::kConv) && (op->outputs.size() == 2)) {
+ return false;
+ }
+
+ Operator* ac_op = nullptr;
+ switch (op->fused_activation_function) {
+ case FusedActivationFunctionType::kRelu:
+ ac_op = new ReluOperator;
+ break;
+ case FusedActivationFunctionType::kRelu6:
+ ac_op = new Relu6Operator;
+ break;
+ case FusedActivationFunctionType::kRelu1:
+ ac_op = new Relu1Operator;
+ break;
+ default:
+ return false;
+ }
+
+ // At this point we know that the op has a fused activation function. At the
+ // moment that only happens with ops having a single output, may be
+ // relaxed in the future.
+ CHECK_EQ(op->outputs.size(), 1);
+
+ // Emplace unfused activation function, drop the fused one.
+ model->operators.emplace(it + 1, ac_op);
+ op->fused_activation_function = FusedActivationFunctionType::kNone;
+
+ // Wire up arrays, constructing a new intermediate array to connect the
+ // op to its new unfused activation function.
+ ac_op->outputs = op->outputs;
+ const string& tmp_array_name =
+ AvailableArrayName(*model, op->outputs[0] + "_unfused");
+ CHECK(!model->arrays.count(tmp_array_name));
+ model->GetOrCreateArray(tmp_array_name);
+ ac_op->inputs = {tmp_array_name};
+ op->outputs = {tmp_array_name};
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
new file mode 100644
index 0000000000..c889149ada
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -0,0 +1,1508 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/map.h"
+#include "google/protobuf/text_format.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/strip.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_util.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
+
+using tensorflow::AttrValue;
+using tensorflow::DT_BOOL;
+using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT32;
+using tensorflow::DT_INT64;
+using tensorflow::DT_UINT8;
+using tensorflow::GraphDef;
+using tensorflow::NodeDef;
+using tensorflow::TensorProto;
+using tensorflow::TensorShapeProto;
+
+namespace toco {
+namespace {
+bool HasAttr(const NodeDef& node, const string& attr_name) {
+ return node.attr().count(attr_name) > 0;
+}
+
+const string& GetStringAttr(const NodeDef& node, const string& attr_name) {
+ CHECK(HasAttr(node, attr_name));
+ const auto& attr = node.attr().at(attr_name);
+ CHECK_EQ(attr.value_case(), AttrValue::kS);
+ return attr.s();
+}
+
+int GetIntAttr(const NodeDef& node, const string& attr_name) {
+ CHECK(HasAttr(node, attr_name)) << attr_name << " not found in:\n"
+ << node.DebugString();
+ const auto& attr = node.attr().at(attr_name);
+ CHECK_EQ(attr.value_case(), AttrValue::kI);
+ return attr.i();
+}
+
+float GetFloatAttr(const NodeDef& node, const string& attr_name) {
+ CHECK(HasAttr(node, attr_name));
+ const auto& attr = node.attr().at(attr_name);
+ CHECK_EQ(attr.value_case(), AttrValue::kF);
+ return attr.f();
+}
+
+bool GetBoolAttr(const NodeDef& node, const string& attr_name) {
+ CHECK(HasAttr(node, attr_name));
+ const auto& attr = node.attr().at(attr_name);
+ CHECK_EQ(attr.value_case(), AttrValue::kB);
+ return attr.b();
+}
+
+tensorflow::DataType GetDataTypeAttr(const NodeDef& node,
+ const string& attr_name) {
+ CHECK(HasAttr(node, attr_name));
+ const auto& attr = node.attr().at(attr_name);
+ CHECK_EQ(attr.value_case(), AttrValue::kType);
+ return attr.type();
+}
+
+const TensorShapeProto& GetShapeAttr(const NodeDef& node,
+ const string& attr_name) {
+ CHECK(HasAttr(node, attr_name));
+ const auto& attr = node.attr().at(attr_name);
+ CHECK_EQ(attr.value_case(), AttrValue::kShape);
+ return attr.shape();
+}
+
+const TensorProto& GetTensorAttr(const NodeDef& node, const string& attr_name) {
+ CHECK(HasAttr(node, attr_name));
+ const auto& attr = node.attr().at(attr_name);
+ CHECK_EQ(attr.value_case(), AttrValue::kTensor);
+ return attr.tensor();
+}
+
+const AttrValue::ListValue& GetListAttr(const NodeDef& node,
+ const string& attr_name) {
+ CHECK(HasAttr(node, attr_name));
+ const auto& attr = node.attr().at(attr_name);
+ CHECK_EQ(attr.value_case(), AttrValue::kList);
+ return attr.list();
+}
+
+ArrayDataType ConvertDataType(tensorflow::DataType dtype) {
+ if (dtype == DT_UINT8)
+ return ArrayDataType::kUint8;
+ else if (dtype == DT_FLOAT)
+ return ArrayDataType::kFloat;
+ else if (dtype == DT_BOOL)
+ return ArrayDataType::kBool;
+ else if (dtype == DT_INT32)
+ return ArrayDataType::kInt32;
+ else if (dtype == DT_INT64)
+ return ArrayDataType::kInt64;
+ else
+ LOG(INFO) << "Unsupported data type in placehoder op: " << dtype;
+ return ArrayDataType::kNone;
+}
+
+void ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField<
+ tensorflow::TensorShapeProto_Dim>& input_dims,
+ Shape* shape) {
+ std::vector<int> input_dims_only_sizes;
+ for (auto& d : input_dims) {
+ if (d.size() == 0) {
+ // Some TensorFlow shapes contain a 0 dim, effectively making
+ // them of flat size 0 even though they have other nonzero dims.
+ // This breaks our invariant, that array dims can't be 0.
+ // For now, tweaking this to record a 0-D shape instead.
+ input_dims_only_sizes.clear();
+ break;
+ }
+ input_dims_only_sizes.push_back(d.size());
+ }
+ *shape->mutable_dims() = input_dims_only_sizes;
+}
+
+void ImportFloatArray(const TensorProto& input_tensor, Array* output_array) {
+ CHECK_EQ(input_tensor.dtype(), DT_FLOAT);
+ const auto& input_shape = input_tensor.tensor_shape();
+ CHECK_LE(input_shape.dim_size(), 4);
+ ImportShape(input_shape.dim(), output_array->mutable_shape());
+ int input_flat_size = 1;
+ for (int k = 0; k < input_shape.dim_size(); k++) {
+ input_flat_size *= input_shape.dim(k).size();
+ }
+ auto& output_float_data =
+ output_array->GetMutableBuffer<ArrayDataType::kFloat>().data;
+ output_float_data.resize(input_flat_size);
+ if (input_tensor.float_val_size()) {
+ for (int i = 0; i < input_tensor.float_val_size(); i++) {
+ output_float_data[i] = input_tensor.float_val(i);
+ }
+ } else if (input_tensor.tensor_content().size() ==
+ input_flat_size * sizeof(float)) {
+ toco::port::CopyToBuffer(input_tensor.tensor_content(),
+ reinterpret_cast<char*>(output_float_data.data()));
+ } else {
+ LOG(FATAL) << "Neither input_content nor float_val have the right "
+ "dimensions for this float tensor.";
+ }
+}
+
+void ImportInt32Array(const TensorProto& input_tensor, Array* output_array) {
+ CHECK_EQ(input_tensor.dtype(), DT_INT32);
+ const auto& input_shape = input_tensor.tensor_shape();
+ CHECK_LE(input_shape.dim_size(), 4);
+ ImportShape(input_shape.dim(), output_array->mutable_shape());
+ int input_flat_size = 1;
+ for (int k = 0; k < input_shape.dim_size(); k++) {
+ input_flat_size *= input_shape.dim(k).size();
+ }
+ auto& output_int_data =
+ output_array->GetMutableBuffer<ArrayDataType::kInt32>().data;
+ output_int_data.resize(input_flat_size);
+ if (input_tensor.int_val_size()) {
+ for (int i = 0; i < input_tensor.int_val_size(); i++) {
+ output_int_data[i] = input_tensor.int_val(i);
+ }
+ } else if (input_tensor.tensor_content().size() ==
+ input_flat_size * sizeof(int32)) {
+ toco::port::CopyToBuffer(input_tensor.tensor_content(),
+ reinterpret_cast<char*>(output_int_data.data()));
+ } else {
+ LOG(FATAL) << "Neither input_content nor int_val have the right "
+ "dimensions for this int32 tensor.";
+ }
+}
+
+void ImportInt64Array(const TensorProto& input_tensor, Array* output_array) {
+ CHECK_EQ(input_tensor.dtype(), DT_INT64);
+ const auto& input_shape = input_tensor.tensor_shape();
+ CHECK_LE(input_shape.dim_size(), 4);
+ ImportShape(input_shape.dim(), output_array->mutable_shape());
+ int input_flat_size = 1;
+ for (int k = 0; k < input_shape.dim_size(); k++) {
+ input_flat_size *= input_shape.dim(k).size();
+ }
+ auto& output_int_data =
+ output_array->GetMutableBuffer<ArrayDataType::kInt64>().data;
+ output_int_data.resize(input_flat_size);
+ if (input_tensor.int64_val_size()) {
+ for (int i = 0; i < input_tensor.int64_val_size(); i++) {
+ output_int_data[i] = input_tensor.int64_val(i);
+ }
+ } else if (input_tensor.tensor_content().size() ==
+ input_flat_size * sizeof(int64)) {
+ toco::port::CopyToBuffer(input_tensor.tensor_content(),
+ reinterpret_cast<char*>(output_int_data.data()));
+ } else {
+ LOG(FATAL) << "Neither input_content nor int64_val have the right "
+ "dimensions for this int64 tensor.";
+ }
+}
+
+// Count the number of inputs of a given node. If `drop_control_dependency` is
+// true, count the number of non-control-dependency inputs.
+size_t GetInputsCount(const NodeDef& node, bool drop_control_dependency) {
+ if (drop_control_dependency) {
+ for (size_t i = 0; i < node.input_size(); ++i) {
+ if (node.input(i)[0] == '^') {
+ LOG(INFO) << "Reached first control dependency input: "
+ << node.input(i);
+ return i;
+ }
+ }
+ return node.input_size();
+ } else {
+ return node.input_size();
+ }
+}
+
+void ConvertConstOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Const");
+ const auto& tensor = GetTensorAttr(node, "value");
+ const auto dtype = GetDataTypeAttr(node, "dtype");
+
+ auto& array = model->GetOrCreateArray(node.name());
+ array.data_type = dtype == DT_FLOAT
+ ? ArrayDataType::kFloat
+ : dtype == DT_INT32
+ ? ArrayDataType::kInt32
+ : dtype == DT_INT64 ? ArrayDataType::kInt64
+ : ArrayDataType::kNone;
+ if (dtype == DT_FLOAT) {
+ ImportFloatArray(tensor, &array);
+ } else if (dtype == DT_INT32) {
+ ImportInt32Array(tensor, &array);
+ } else if (dtype == DT_INT64) {
+ ImportInt64Array(tensor, &array);
+ } else {
+ // do nothing, silently ignore the Const data. For example, there are consts
+ // of string type. We just make a dummy buffer to indicate that this array
+ // does not rely on external input.
+ array.GetMutableBuffer<ArrayDataType::kNone>();
+ }
+}
+
+void ConvertConvOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Conv2D");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+
+ // We only support NHWC, which is the default data_format.
+ // So if data_format is not defined, we're all good.
+ if (node.attr().count("data_format")) {
+ CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
+ }
+ CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
+
+ const auto& input_name = node.input(0);
+ const auto& weights_name = node.input(1);
+ const auto& reordered_weights_name = weights_name + "_reordered";
+ // Check if a ReorderAxesOperator was already created for these weights
+ // (that happens when multiple layers share the same weights).
+ const Operator* existing_reorder =
+ GetOpWithOutput(*model, reordered_weights_name);
+ if (existing_reorder) {
+ // Check that it is safe to rely on the _reordered naming of the output
+ // array!
+ CHECK(existing_reorder->type == OperatorType::kReorderAxes);
+ } else {
+ // Create a new ReorderAxesOperator
+ auto* reorder = new ReorderAxesOperator;
+ reorder->inputs = {weights_name};
+ reorder->outputs = {reordered_weights_name};
+ reorder->input_axes_order = AxesOrder::kHWIO;
+ reorder->output_axes_order = AxesOrder::kOHWI;
+ model->operators.emplace_back(reorder);
+ }
+ auto* conv = new ConvOperator;
+ conv->inputs = {input_name, reordered_weights_name};
+ conv->outputs = {node.name()};
+ const auto& strides = GetListAttr(node, "strides");
+ CHECK_EQ(strides.i_size(), 4);
+ CHECK_EQ(strides.i(0), 1);
+ CHECK_EQ(strides.i(3), 1);
+ conv->stride_height = strides.i(1);
+ conv->stride_width = strides.i(2);
+ const auto& padding = GetStringAttr(node, "padding");
+ if (padding == "SAME") {
+ conv->padding.type = PaddingType::kSame;
+ } else if (padding == "VALID") {
+ conv->padding.type = PaddingType::kValid;
+ } else {
+ LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
+ }
+ model->operators.emplace_back(conv);
+}
+
+void ConvertDepthwiseConvOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "DepthwiseConv2dNative");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+
+ // We only support NHWC, which is the default data_format.
+ // So if data_format is not defined, we're all good.
+ if (node.attr().count("data_format")) {
+ CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
+ }
+ CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
+
+ const auto& input_name = node.input(0);
+ const auto& weights_name = node.input(1);
+ const auto& reordered_weights_name = weights_name + "_reordered";
+ // Check if a ReorderAxesOperator was already created for these weights
+ // (that happens when multiple layers share the same weights).
+ const Operator* existing_reorder =
+ GetOpWithOutput(*model, reordered_weights_name);
+ if (existing_reorder) {
+ // Check that it is safe to rely on the _reordered naming of the output
+ // array!
+ CHECK(existing_reorder->type == OperatorType::kReorderAxes);
+ } else {
+ // Create a new ReorderAxesOperator
+ auto* reorder = new ReorderAxesOperator;
+ reorder->inputs = {weights_name};
+ reorder->outputs = {reordered_weights_name};
+ reorder->input_axes_order = AxesOrder::kHWIM;
+ reorder->output_axes_order = AxesOrder::k1HWO;
+ model->operators.emplace_back(reorder);
+ }
+ auto* conv = new DepthwiseConvOperator;
+ conv->inputs = {input_name, reordered_weights_name};
+ conv->outputs = {node.name()};
+ const auto& strides = GetListAttr(node, "strides");
+ CHECK_EQ(strides.i_size(), 4);
+ CHECK_EQ(strides.i(0), 1);
+ CHECK_EQ(strides.i(3), 1);
+ conv->stride_height = strides.i(1);
+ conv->stride_width = strides.i(2);
+ const auto& padding = GetStringAttr(node, "padding");
+ if (padding == "SAME") {
+ conv->padding.type = PaddingType::kSame;
+ } else if (padding == "VALID") {
+ conv->padding.type = PaddingType::kValid;
+ } else {
+ LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
+ }
+ model->operators.emplace_back(conv);
+}
+
+void ConvertDepthToSpaceOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "DepthToSpace");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
+ auto* op = new DepthToSpaceOperator;
+ op->inputs.push_back(node.input(0));
+ op->outputs.push_back(node.name());
+ op->block_size = GetIntAttr(node, "block_size");
+ QCHECK_GE(op->block_size, 2);
+ model->operators.emplace_back(op);
+}
+
+void ConvertSpaceToDepthOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "SpaceToDepth");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
+ auto* op = new SpaceToDepthOperator;
+ op->inputs.push_back(node.input(0));
+ op->outputs.push_back(node.name());
+ op->block_size = GetIntAttr(node, "block_size");
+ QCHECK_GE(op->block_size, 2);
+ model->operators.emplace_back(op);
+}
+
+void ConvertBiasAddOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "BiasAdd");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ const auto& input_name = node.input(0);
+ const auto& bias_name = node.input(1);
+ CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
+ auto* biasadd = new AddOperator;
+ biasadd->inputs.push_back(input_name);
+ biasadd->inputs.push_back(bias_name);
+ biasadd->outputs.push_back(node.name());
+ model->operators.emplace_back(biasadd);
+}
+
+void ConvertReluOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Relu");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ const auto& input_name = node.input(0);
+ auto* relu = new ReluOperator;
+ relu->inputs.push_back(input_name);
+ relu->outputs.push_back(node.name());
+ model->operators.emplace_back(relu);
+}
+
+void ConvertRelu6Operator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Relu6");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ const auto& input_name = node.input(0);
+ auto* op = new Relu6Operator;
+ op->inputs.push_back(input_name);
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertLogisticOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Sigmoid");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ const auto& input_name = node.input(0);
+ auto* op = new LogisticOperator;
+ op->inputs.push_back(input_name);
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertTanhOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Tanh");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ const auto& input_name = node.input(0);
+ auto* op = new TanhOperator;
+ op->inputs.push_back(input_name);
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertDivOperator(const NodeDef& node, Model* model) {
+ CHECK(node.op() == "Div" || node.op() == "RealDiv");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new DivOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertIdentityOperator(const NodeDef& node, Model* model) {
+ CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" ||
+ node.op() == "PlaceholderWithDefault");
+ auto* op = new TensorFlowIdentityOperator;
+ // Amazingly, some TensorFlow graphs (at least rajeev_lstm.pb) have
+ // identity nodes with multiple inputs, but the other inputs seem
+ // to be gratuitous (in the case of rajeev_lstm.pb, these are
+ // enumerating the LSTM state arrays). We will just ignore extra
+ // inputs beyond the first input.
+ CHECK_GE(node.input_size(), 1);
+ const auto& input_name = node.input(0);
+ op->inputs.push_back(input_name);
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertFakeQuantWithMinMaxArgs(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "FakeQuantWithMinMaxArgs");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ auto* op = new FakeQuantOperator;
+ op->inputs.push_back(node.input(0));
+ op->minmax.reset(new MinMax);
+ auto& minmax = *op->minmax;
+ minmax.min = GetFloatAttr(node, "min");
+ minmax.max = GetFloatAttr(node, "max");
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertFakeQuantWithMinMaxVars(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "FakeQuantWithMinMaxVars");
+ const int num_inputs =
+ GetInputsCount(node, model->flags.drop_control_dependency());
+ CHECK(num_inputs == 3 || num_inputs == 4);
+ auto* op = new FakeQuantOperator;
+ for (int i = 0; i < 3; i++) {
+ op->inputs.push_back(node.input(i));
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertRsqrtOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Rsqrt");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ auto* op = new TensorFlowRsqrtOperator;
+ op->inputs.push_back(node.input(0));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertSqrtOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Sqrt");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ auto* op = new TensorFlowSqrtOperator;
+ op->inputs.push_back(node.input(0));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertSqueezeOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Squeeze");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ auto* op = new SqueezeOperator;
+ op->inputs.push_back(node.input(0));
+ op->outputs.push_back(node.name());
+
+ const auto& squeeze_dims = GetListAttr(node, "squeeze_dims");
+ for (int i = 0; i < squeeze_dims.i_size(); ++i) {
+ op->squeeze_dims.push_back(squeeze_dims.i(i));
+ }
+
+ model->operators.emplace_back(op);
+}
+
+void ConvertSquareOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Square");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ auto* op = new TensorFlowSquareOperator;
+ op->inputs.push_back(node.input(0));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertAddOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Add");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new AddOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertMulOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Mul");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new MulOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertSubOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Sub");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new SubOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertSumOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Sum");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new TensorFlowSumOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertTileOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Tile");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new TensorFlowTileOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertSliceOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Slice");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 3);
+ auto* op = new SliceOperator;
+ for (int i = 0; i < 3; ++i) {
+ op->inputs.push_back(node.input(i));
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertPadOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Pad");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new PadOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertShapeOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Shape");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ auto* op = new TensorFlowShapeOperator;
+ op->inputs.push_back(node.input(0));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertSplitOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Split");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new TensorFlowSplitOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ const int num_split = GetIntAttr(node, "num_split");
+ op->outputs.push_back(node.name());
+ for (int i = 1; i < num_split; i++) {
+ op->outputs.push_back(absl::StrCat(node.name(), ":", i));
+ }
+ op->num_split = num_split;
+ model->operators.emplace_back(op);
+}
+
+void ConvertMergeOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Merge");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new TensorFlowMergeOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertSwitchOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Switch");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new TensorFlowSwitchOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ // Switch operators have two outputs: "name" and "name:1".
+ op->outputs.push_back(node.name() + ":1");
+ model->operators.emplace_back(op);
+}
+void ConvertSoftmaxOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Softmax");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ const auto& input_name = node.input(0);
+ auto* softmax = new SoftmaxOperator;
+ softmax->inputs.push_back(input_name);
+ softmax->outputs.push_back(node.name());
+ // TensorFlow's Softmax doesn't seem to admit a 'beta' parameter.
+ CHECK(!node.attr().count("beta")); // Stab in the dark, just in case.
+ softmax->beta = 1.f;
+ model->operators.emplace_back(softmax);
+}
+
+void ConvertLRNOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "LRN");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ const auto& input_name = node.input(0);
+ auto* lrn = new LocalResponseNormalizationOperator;
+ lrn->inputs.push_back(input_name);
+ lrn->outputs.push_back(node.name());
+ lrn->range = GetIntAttr(node, "depth_radius");
+ lrn->bias = GetFloatAttr(node, "bias");
+ lrn->alpha = GetFloatAttr(node, "alpha");
+ lrn->beta = GetFloatAttr(node, "beta");
+ model->operators.emplace_back(lrn);
+}
+
+void ConvertMaxPoolOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "MaxPool");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ const auto& input_name = node.input(0);
+ if (HasAttr(node, "T")) {
+ CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
+ } else {
+ LOG(WARNING) << "Found MaxPool operator missing 'T' attribute";
+ }
+ auto* maxpool = new MaxPoolOperator;
+ maxpool->inputs.push_back(input_name);
+ maxpool->outputs.push_back(node.name());
+ const auto& strides = GetListAttr(node, "strides");
+ CHECK_EQ(strides.i_size(), 4);
+ CHECK_EQ(strides.i(0), 1);
+ CHECK_EQ(strides.i(3), 1);
+ maxpool->stride_height = strides.i(1);
+ maxpool->stride_width = strides.i(2);
+ const auto& ksize = GetListAttr(node, "ksize");
+ CHECK_EQ(ksize.i_size(), 4);
+ CHECK_EQ(ksize.i(0), 1);
+ CHECK_EQ(ksize.i(3), 1);
+ maxpool->kheight = ksize.i(1);
+ maxpool->kwidth = ksize.i(2);
+ const auto& padding = GetStringAttr(node, "padding");
+ if (padding == "SAME") {
+ maxpool->padding.type = PaddingType::kSame;
+ } else if (padding == "VALID") {
+ maxpool->padding.type = PaddingType::kValid;
+ } else {
+ LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
+ }
+ model->operators.emplace_back(maxpool);
+}
+
+void ConvertAvgPoolOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "AvgPool");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ const auto& input_name = node.input(0);
+ CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
+ auto* avgpool = new AveragePoolOperator;
+ avgpool->inputs.push_back(input_name);
+ avgpool->outputs.push_back(node.name());
+ const auto& strides = GetListAttr(node, "strides");
+ CHECK_EQ(strides.i_size(), 4);
+ CHECK_EQ(strides.i(0), 1);
+ CHECK_EQ(strides.i(3), 1);
+ avgpool->stride_height = strides.i(1);
+ avgpool->stride_width = strides.i(2);
+ const auto& ksize = GetListAttr(node, "ksize");
+ CHECK_EQ(ksize.i_size(), 4);
+ CHECK_EQ(ksize.i(0), 1);
+ CHECK_EQ(ksize.i(3), 1);
+ avgpool->kheight = ksize.i(1);
+ avgpool->kwidth = ksize.i(2);
+ const auto& padding = GetStringAttr(node, "padding");
+ if (padding == "SAME") {
+ avgpool->padding.type = PaddingType::kSame;
+ } else if (padding == "VALID") {
+ avgpool->padding.type = PaddingType::kValid;
+ } else {
+ LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
+ }
+ model->operators.emplace_back(avgpool);
+}
+
+void ConvertReshapeOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Reshape");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new TensorFlowReshapeOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertMatMulOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "MatMul");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ // Transpose flags should be easy to support, but we don't have a
+ // GraphDef with them to test on at the moment.
+ CHECK_EQ(GetBoolAttr(node, "transpose_a"), false);
+ CHECK_EQ(GetBoolAttr(node, "transpose_b"), false);
+ const auto& input_name = node.input(0);
+ const auto& weights_name = node.input(1);
+ const auto& reordered_weights_name = weights_name + "_reordered";
+ // Check if a ReorderAxesOperator was already created for these weights
+ // (that happens when multiple layers share the same weights).
+ const Operator* existing_reorder =
+ GetOpWithOutput(*model, reordered_weights_name);
+ if (existing_reorder) {
+ // Check that it is safe to rely on the _reordered naming of the output
+ // array!
+ CHECK(existing_reorder->type == OperatorType::kReorderAxes);
+ } else {
+ // Create a new ReorderAxesOperator
+ auto* reorder = new ReorderAxesOperator;
+ reorder->inputs = {weights_name};
+ reorder->outputs = {reordered_weights_name};
+ reorder->input_axes_order = AxesOrder::kRC;
+ reorder->output_axes_order = AxesOrder::kCR;
+ model->operators.emplace_back(reorder);
+ }
+ auto* matmul = new TensorFlowMatMulOperator;
+ matmul->inputs = {input_name, reordered_weights_name};
+ matmul->outputs = {node.name()};
+ model->operators.emplace_back(matmul);
+}
+
+void ConvertConcatOperator(const NodeDef& node, Model* model) {
+ Operator* op = nullptr;
+ if (node.op() == "Concat") {
+ op = new TensorFlowConcatOperator;
+ } else if (node.op() == "ConcatV2") {
+ op = new TensorFlowConcatV2Operator;
+ } else {
+ LOG(FATAL) << "Expected Concat or ConcatV2";
+ }
+ const int num_inputs =
+ GetInputsCount(node, model->flags.drop_control_dependency());
+ CHECK_GE(num_inputs, 2);
+ CHECK_EQ(num_inputs, 1 + GetIntAttr(node, "N"));
+ for (int i = 0; i < num_inputs; ++i) {
+ op->inputs.push_back(node.input(i));
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertAllOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "All");
+ auto* op = new TensorFlowAllOperator;
+ const int num_inputs =
+ GetInputsCount(node, model->flags.drop_control_dependency());
+ for (int i = 0; i < num_inputs; ++i) {
+ op->inputs.push_back(node.input(i));
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertAssertOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Assert");
+ auto* op = new TensorFlowAssertOperator;
+ const int num_inputs =
+ GetInputsCount(node, model->flags.drop_control_dependency());
+ for (int i = 0; i < num_inputs; ++i) {
+ op->inputs.push_back(node.input(i));
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertLessOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Less");
+ auto* op = new TensorFlowLessOperator;
+ const int num_inputs =
+ GetInputsCount(node, model->flags.drop_control_dependency());
+ for (int i = 0; i < num_inputs; ++i) {
+ op->inputs.push_back(node.input(i));
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertLessEqualOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "LessEqual");
+ auto* op = new TensorFlowLessEqualOperator;
+ const int num_inputs =
+ GetInputsCount(node, model->flags.drop_control_dependency());
+ for (int i = 0; i < num_inputs; ++i) {
+ op->inputs.push_back(node.input(i));
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertGreaterOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Greater");
+ auto* op = new TensorFlowGreaterOperator;
+ const int num_inputs =
+ GetInputsCount(node, model->flags.drop_control_dependency());
+ for (int i = 0; i < num_inputs; ++i) {
+ op->inputs.push_back(node.input(i));
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertGreaterEqualOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "GreaterEqual");
+ auto* op = new TensorFlowGreaterEqualOperator;
+ const int num_inputs =
+ GetInputsCount(node, model->flags.drop_control_dependency());
+ for (int i = 0; i < num_inputs; ++i) {
+ op->inputs.push_back(node.input(i));
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertMaxOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Max");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new TensorFlowMaxOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertMinOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Min");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new TensorFlowMinOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertMaximumOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Maximum");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new TensorFlowMaximumOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertMinimumOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Minimum");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new TensorFlowMinimumOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertUnsupportedOperator(const NodeDef& node, Model* model) {
+ LOG(INFO) << "Converting unsupported operation: " << node.op();
+ auto* op = new TensorFlowUnsupportedOperator;
+ const int num_inputs =
+ GetInputsCount(node, model->flags.drop_control_dependency());
+ for (int i = 0; i < num_inputs; ++i) {
+ op->inputs.push_back(node.input(i));
+ }
+ op->outputs.push_back(node.name());
+ op->tensorflow_op = node.op();
+ node.SerializeToString(&op->tensorflow_node_def);
+ model->operators.emplace_back(op);
+ if (HasAttr(node, "_output_quantized")) {
+ op->quantized = GetBoolAttr(node, "_output_quantized");
+ }
+ if (HasAttr(node, "_output_types")) {
+ const auto& output_types = GetListAttr(node, "_output_types");
+ for (int i = 0; i < output_types.type_size(); ++i) {
+ op->output_data_types.push_back(ConvertDataType(output_types.type(i)));
+ }
+ }
+}
+
+void ConvertStridedSliceOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "StridedSlice");
+ CHECK_EQ(node.input_size(), 4);
+
+ // Only a subset of the full TF op functionality is supported now.
+ if ( // No 64-bit indices.
+ GetDataTypeAttr(node, "Index") != DT_INT32 ||
+ // No dimensionality changes.
+ GetIntAttr(node, "new_axis_mask") != 0 ||
+ GetIntAttr(node, "shrink_axis_mask") != 0 ||
+ // No sparse indices.
+ GetIntAttr(node, "ellipsis_mask") != 0 ||
+ // Only 4D tensors are supported.
+ GetIntAttr(node, "begin_mask") > 15 ||
+ GetIntAttr(node, "end_mask") > 15) {
+ ConvertUnsupportedOperator(node, model);
+ return;
+ }
+
+ auto* op = new StridedSliceOperator;
+ for (const auto& input : node.input()) {
+ op->inputs.push_back(input);
+ }
+ op->outputs.push_back(node.name());
+
+ op->begin_mask = GetIntAttr(node, "begin_mask");
+ op->ellipsis_mask = GetIntAttr(node, "ellipsis_mask");
+ op->end_mask = GetIntAttr(node, "end_mask");
+ op->new_axis_mask = GetIntAttr(node, "new_axis_mask");
+ op->shrink_axis_mask = GetIntAttr(node, "shrink_axis_mask");
+ model->operators.emplace_back(op);
+}
+
+void ConvertPlaceholderOperator(const NodeDef& node, Model* model) {
+ CHECK(node.op() == "Placeholder" || node.op() == "LegacyFedInput");
+ if (node.op() == "Placeholder") {
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 0);
+ }
+ auto& array = model->GetOrCreateArray(node.name());
+ if (node.attr().count("dtype")) {
+ array.data_type = ConvertDataType(GetDataTypeAttr(node, "dtype"));
+ }
+ if (node.attr().count("shape")) {
+ const auto& shape = GetShapeAttr(node, "shape");
+ auto num_dims = shape.dim_size();
+ bool has_wildcard = false;
+ for (std::size_t i = 0; i < num_dims; i++) {
+ if (shape.dim(i).size() == -1) {
+ has_wildcard = true;
+ }
+ }
+ // TODO(b/62716978): This logic needs to be revisted. During dims
+ // refactoring it is an interim fix.
+ if (num_dims > 0 && !has_wildcard) {
+ auto& dst_array_dims = *array.mutable_shape()->mutable_dims();
+ dst_array_dims.resize(num_dims);
+ for (std::size_t i = 0; i < num_dims; i++) {
+ dst_array_dims[i] = shape.dim(i).size();
+ }
+ }
+ }
+}
+
+void ConvertNoOpOperator(const NodeDef& node, Model* model) {}
+
+ArrayDataType GetArrayDataType(tensorflow::DataType tf_data_type) {
+ if (tf_data_type == DT_UINT8) {
+ return ArrayDataType::kUint8;
+ } else if (tf_data_type == DT_INT32) {
+ return ArrayDataType::kInt32;
+ } else if (tf_data_type == DT_FLOAT) {
+ return ArrayDataType::kFloat;
+ } else {
+ return ArrayDataType::kNone;
+ }
+}
+
+void ConvertCastOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Cast");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ const auto tf_src_dtype = GetDataTypeAttr(node, "SrcT");
+ const auto tf_dst_dtype = GetDataTypeAttr(node, "DstT");
+ CHECK(tf_src_dtype == DT_UINT8 || tf_src_dtype == DT_INT32 ||
+ tf_src_dtype == DT_FLOAT);
+ CHECK(tf_dst_dtype == DT_UINT8 || tf_dst_dtype == DT_INT32 ||
+ tf_dst_dtype == DT_FLOAT);
+ CHECK_NE(tf_src_dtype, tf_dst_dtype)
+ << "Same input and output data type. No need to cast.";
+ auto* op = new CastOperator;
+ op->src_data_type = GetArrayDataType(tf_src_dtype);
+ op->dst_data_type = GetArrayDataType(tf_dst_dtype);
+ op->inputs.push_back(node.input(0));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertFloorOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Floor");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ const auto data_type = GetDataTypeAttr(node, "T");
+ CHECK(data_type == DT_FLOAT);
+ auto* op = new FloorOperator;
+ op->inputs.push_back(node.input(0));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertGatherOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Gather");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ const auto indices_data_type = GetDataTypeAttr(node, "Tindices");
+ CHECK(indices_data_type == DT_INT32);
+ auto* op = new GatherOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertResizeBilinearOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "ResizeBilinear");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new ResizeBilinearOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertBatchNormWithGlobalNormalizationOperator(const NodeDef& node,
+ Model* model) {
+ CHECK_EQ(node.op(), "BatchNormWithGlobalNormalization");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 5);
+
+ // TODO(ahentz): to really match tensorflow we need to add variance_epsilon
+ // to the input, before feeding it into TensorFlowRsqrtOperator.
+ // CHECK_EQ(GetFloatAttr(node, "variance_epsilon"), 0.001f);
+
+ string multiplier = node.name() + "_mul";
+ if (GetBoolAttr(node, "scale_after_normalization")) {
+ // Create graph:
+ // v -> RSQRT ->
+ // MUL -> multiplier
+ // gamma ----->
+ string rsqrt = node.name() + "_rsqrt";
+
+ auto* rsqrt_op = new TensorFlowRsqrtOperator;
+ rsqrt_op->inputs.push_back(node.input(2));
+ rsqrt_op->outputs.push_back(rsqrt);
+ model->operators.emplace_back(rsqrt_op);
+
+ auto* mul_op = new MulOperator;
+ mul_op->inputs.push_back(rsqrt);
+ mul_op->inputs.push_back(node.input(4));
+ mul_op->outputs.push_back(multiplier);
+ model->operators.emplace_back(mul_op);
+ } else {
+ // Create graph:
+ // v -> RSQRT -> multiplier
+ auto* rsqrt_op = new TensorFlowRsqrtOperator;
+ rsqrt_op->inputs.push_back(node.input(2));
+ rsqrt_op->outputs.push_back(multiplier);
+ model->operators.emplace_back(rsqrt_op);
+ }
+
+ auto* op = new BatchNormalizationOperator;
+ op->global_normalization = true;
+
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->inputs.push_back(multiplier);
+ op->inputs.push_back(node.input(3));
+ op->outputs.push_back(node.name());
+
+ model->operators.emplace_back(op);
+}
+
+void ConvertFusedBatchNormOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "FusedBatchNorm");
+ CHECK_EQ(node.input_size(), 5);
+
+ // Declare shortcuts for the inputs.
+ const string& gamma_input = node.input(1);
+ const string& beta_input = node.input(2);
+ const string& moving_mean_input = node.input(3);
+ const string& moving_variance_input = node.input(4);
+
+ // Create an array holding the epsilon value (typically, 0.001).
+ const string epsilon_array_name = node.name() + "_epsilon_array";
+ auto& epsilon_array = model->GetOrCreateArray(epsilon_array_name);
+ epsilon_array.data_type = ArrayDataType::kFloat;
+ *epsilon_array.mutable_shape()->mutable_dims() = {1};
+ epsilon_array.GetMutableBuffer<ArrayDataType::kFloat>().data.push_back(
+ GetFloatAttr(node, "epsilon"));
+
+ // Add epsilon to the moving variance.
+ const string epsilon_add_op_name = node.name() + "_epsilon";
+ auto* epsilon_add_op = new AddOperator;
+ epsilon_add_op->inputs.push_back(moving_variance_input);
+ epsilon_add_op->inputs.push_back(epsilon_array_name);
+ epsilon_add_op->outputs.push_back(epsilon_add_op_name);
+ model->operators.emplace_back(epsilon_add_op);
+
+ // Take the inverse square root of the (variance + epsilon).
+ const string rsqrt_op_name = node.name() + "_rsqrt";
+ auto* rsqrt_op = new TensorFlowRsqrtOperator;
+ rsqrt_op->inputs.push_back(epsilon_add_op_name);
+ rsqrt_op->outputs.push_back(rsqrt_op_name);
+ model->operators.emplace_back(rsqrt_op);
+
+ // Multiply the result by gamma.
+ const string multiplier = node.name() + "_mul";
+ auto* mul_op = new MulOperator;
+ mul_op->inputs.push_back(rsqrt_op_name);
+ mul_op->inputs.push_back(gamma_input);
+ mul_op->outputs.push_back(multiplier);
+ model->operators.emplace_back(mul_op);
+
+ // Now we have all required inputs for the BatchNormalizationOperator.
+ auto* op = new BatchNormalizationOperator;
+ op->global_normalization = true;
+
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(moving_mean_input);
+ op->inputs.push_back(multiplier);
+ op->inputs.push_back(beta_input);
+ op->outputs.push_back(node.name());
+
+ model->operators.emplace_back(op);
+}
+
+void ConvertSpaceToBatchNDOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "SpaceToBatchND");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 3);
+ CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
+ CHECK_EQ(GetDataTypeAttr(node, "Tpaddings"), DT_INT32);
+ auto* op = new SpaceToBatchNDOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->inputs.push_back(node.input(2));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertBatchToSpaceNDOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "BatchToSpaceND");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 3);
+ CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
+ CHECK_EQ(GetDataTypeAttr(node, "Tcrops"), DT_INT32);
+ auto* op = new BatchToSpaceNDOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->inputs.push_back(node.input(2));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertMeanOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Mean");
+ CHECK_EQ(node.input_size(), 2);
+ auto* op = new MeanOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertSvdfOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Svdf");
+ bool has_bias = (node.input_size() == 4);
+ auto* op = new SvdfOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->inputs.push_back(node.input(2));
+ if (has_bias) {
+ op->inputs.push_back(node.input(3));
+ }
+ op->outputs.push_back(node.name() + "_state");
+ op->outputs.push_back(node.name());
+ if (node.attr().at("ActivationFunction").s() == "Relu") {
+ op->fused_activation_function = FusedActivationFunctionType::kRelu;
+ } else {
+ op->fused_activation_function = FusedActivationFunctionType::kNone;
+ }
+ op->rank = node.attr().at("Rank").i();
+ model->operators.emplace_back(op);
+}
+
+void StripCaretFromArrayNames(Model* model) {
+ for (auto& op : model->operators) {
+ for (auto& input : op->inputs) {
+ input = string(absl::StripPrefix(input, "^"));
+ }
+ for (auto& output : op->outputs) {
+ output = string(absl::StripPrefix(output, "^"));
+ }
+ }
+ for (auto& array : model->arrays) {
+ if (absl::StartsWith(array.first, "^")) {
+ LOG(FATAL) << "What?";
+ }
+ }
+}
+
+void AddExtraOutputsFedIntoOtherOps(Model* model) {
+ for (const auto& consumer_op : model->operators) {
+ for (const string& input : consumer_op->inputs) {
+ const std::vector<string>& split = absl::StrSplit(input, ':');
+ if (split.size() != 2) {
+ continue;
+ }
+ int output_index = 0;
+ if (!absl::SimpleAtoi(split[1], &output_index)) {
+ continue;
+ }
+ auto* producer_op = GetOpWithOutput(*model, split[0]);
+ if (!producer_op) {
+ continue;
+ }
+ while (producer_op->outputs.size() <= output_index) {
+ using toco::port::StringF;
+ producer_op->outputs.push_back(
+ StringF("%s:%d", split[0], producer_op->outputs.size()));
+ }
+ }
+ }
+}
+
+bool InlineAllFunctions(GraphDef* graphdef) {
+ if (graphdef->library().function().empty()) {
+ VLOG(kLogLevelModelUnchanged) << "No functions to inline.";
+ return false;
+ }
+
+ // Override "_noinline" attribute on all functions
+ GraphDef graphdef_copy(*graphdef);
+ for (auto& function :
+ (*graphdef_copy.mutable_library()->mutable_function())) {
+ auto* attributes = function.mutable_attr();
+ if (attributes->count(tensorflow::kNoInlineAttr) != 0) {
+ (*attributes)[tensorflow::kNoInlineAttr].set_b(false);
+ }
+ }
+
+ // Construct minimum resources needed to use ExpandInlineFunctions().
+ tensorflow::SessionOptions options;
+ auto* device_count = options.config.mutable_device_count();
+ device_count->insert({"CPU", 1});
+ std::vector<tensorflow::Device*> devices;
+ TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
+ options, "/job:localhost/replica:0/task:0", &devices));
+
+ tensorflow::FunctionLibraryDefinition fld(tensorflow::OpRegistry::Global(),
+ graphdef_copy.library());
+ tensorflow::DeviceMgr device_mgr(devices);
+ tensorflow::OptimizerOptions o_opts;
+ tensorflow::ProcessFunctionLibraryRuntime pflr(
+ &device_mgr, tensorflow::Env::Default(), TF_GRAPH_DEF_VERSION, &fld,
+ o_opts, nullptr);
+ tensorflow::FunctionLibraryRuntime* flr;
+ flr = pflr.GetFLR("/job:localhost/replica:0/task:0/cpu:0");
+
+ tensorflow::Graph graph(fld);
+ tensorflow::GraphConstructorOptions gc_opts;
+ TF_CHECK_OK(
+ tensorflow::ConvertGraphDefToGraph(gc_opts, graphdef_copy, &graph));
+
+ // Iterate over the graph until there are no more nodes to be inlined.
+ bool graph_modified = false;
+ while (tensorflow::ExpandInlineFunctions(flr, &graph)) {
+ graph_modified = true;
+ LOG(INFO) << "Found functions that were inlined.";
+ }
+
+ // Output inlined graph
+ if (graph_modified) {
+ graph.ToGraphDef(graphdef);
+ }
+ return graph_modified;
+}
+} // namespace
+
+std::unique_ptr<Model> ImportTensorFlowGraphDef(const ModelFlags& model_flags,
+ const GraphDef& tf_graph) {
+ LogDumpGraphDef(kLogLevelModelChanged, "AT IMPORT", tf_graph);
+
+ GraphDef inlined_graph(tf_graph);
+ if (InlineAllFunctions(&inlined_graph)) {
+ LogDumpGraphDef(kLogLevelModelChanged, "AFTER INLINING", inlined_graph);
+ }
+
+ Model* model = new Model;
+ ResolveModelFlags(model_flags, model);
+
+ for (const auto& node : inlined_graph.node()) {
+ if (node.op() == "Const") {
+ ConvertConstOperator(node, model);
+ } else if (node.op() == "Conv2D") {
+ ConvertConvOperator(node, model);
+ } else if (node.op() == "DepthwiseConv2dNative") {
+ ConvertDepthwiseConvOperator(node, model);
+ } else if (node.op() == "DepthToSpace") {
+ ConvertDepthToSpaceOperator(node, model);
+ } else if (node.op() == "SpaceToDepth") {
+ ConvertSpaceToDepthOperator(node, model);
+ } else if (node.op() == "BiasAdd") {
+ ConvertBiasAddOperator(node, model);
+ } else if (node.op() == "Relu") {
+ ConvertReluOperator(node, model);
+ } else if (node.op() == "Relu6") {
+ ConvertRelu6Operator(node, model);
+ } else if (node.op() == "Sigmoid") {
+ ConvertLogisticOperator(node, model);
+ } else if (node.op() == "Tanh") {
+ ConvertTanhOperator(node, model);
+ } else if (node.op() == "MaxPool") {
+ ConvertMaxPoolOperator(node, model);
+ } else if (node.op() == "AvgPool") {
+ ConvertAvgPoolOperator(node, model);
+ } else if (node.op() == "Reshape") {
+ ConvertReshapeOperator(node, model);
+ } else if (node.op() == "MatMul") {
+ ConvertMatMulOperator(node, model);
+ } else if (node.op() == "Div" || node.op() == "RealDiv") {
+ ConvertDivOperator(node, model);
+ } else if (node.op() == "Identity" || node.op() == "CheckNumerics") {
+ ConvertIdentityOperator(node, model);
+ } else if (node.op() == "FakeQuantWithMinMaxVars") {
+ ConvertFakeQuantWithMinMaxVars(node, model);
+ } else if (node.op() == "FakeQuantWithMinMaxArgs") {
+ ConvertFakeQuantWithMinMaxArgs(node, model);
+ } else if (node.op() == "Rsqrt") {
+ ConvertRsqrtOperator(node, model);
+ } else if (node.op() == "Squeeze") {
+ ConvertSqueezeOperator(node, model);
+ } else if (node.op() == "Sqrt") {
+ ConvertSqrtOperator(node, model);
+ } else if (node.op() == "Square") {
+ ConvertSquareOperator(node, model);
+ } else if (node.op() == "Add") {
+ ConvertAddOperator(node, model);
+ } else if (node.op() == "Mul") {
+ ConvertMulOperator(node, model);
+ } else if (node.op() == "Sub") {
+ ConvertSubOperator(node, model);
+ } else if (node.op() == "Sum") {
+ ConvertSumOperator(node, model);
+ } else if (node.op() == "Tile") {
+ ConvertTileOperator(node, model);
+ } else if (node.op() == "Concat" || node.op() == "ConcatV2") {
+ ConvertConcatOperator(node, model);
+ } else if (node.op() == "LRN") {
+ ConvertLRNOperator(node, model);
+ } else if (node.op() == "Softmax") {
+ ConvertSoftmaxOperator(node, model);
+ } else if (node.op() == "All") {
+ ConvertAllOperator(node, model);
+ } else if (node.op() == "Assert") {
+ ConvertAssertOperator(node, model);
+ } else if (node.op() == "Less") {
+ ConvertLessOperator(node, model);
+ } else if (node.op() == "LessEqual") {
+ ConvertLessEqualOperator(node, model);
+ } else if (node.op() == "Greater") {
+ ConvertGreaterOperator(node, model);
+ } else if (node.op() == "GreaterEqual") {
+ ConvertGreaterEqualOperator(node, model);
+ } else if (node.op() == "Max") {
+ ConvertMaxOperator(node, model);
+ } else if (node.op() == "Min") {
+ ConvertMinOperator(node, model);
+ } else if (node.op() == "Maximum") {
+ ConvertMaximumOperator(node, model);
+ } else if (node.op() == "Minimum") {
+ ConvertMinimumOperator(node, model);
+ } else if (node.op() == "Merge") {
+ ConvertMergeOperator(node, model);
+ } else if (node.op() == "Pad") {
+ ConvertPadOperator(node, model);
+ } else if (node.op() == "StridedSlice") {
+ ConvertStridedSliceOperator(node, model);
+ } else if (node.op() == "Shape") {
+ ConvertShapeOperator(node, model);
+ } else if (node.op() == "Slice") {
+ ConvertSliceOperator(node, model);
+ } else if (node.op() == "Split") {
+ ConvertSplitOperator(node, model);
+ } else if (node.op() == "Switch") {
+ ConvertSwitchOperator(node, model);
+ } else if (node.op() == "Placeholder") {
+ ConvertPlaceholderOperator(node, model);
+ } else if (node.op() == "PlaceholderWithDefault") {
+ ConvertIdentityOperator(node, model);
+ } else if (node.op() == "LegacyFedInput") {
+ ConvertPlaceholderOperator(node, model);
+ } else if (node.op() == "NoOp") {
+ ConvertNoOpOperator(node, model);
+ } else if (node.op() == "Cast") {
+ ConvertCastOperator(node, model);
+ } else if (node.op() == "Floor") {
+ ConvertFloorOperator(node, model);
+ } else if (node.op() == "Gather") {
+ ConvertGatherOperator(node, model);
+ } else if (node.op() == "ResizeBilinear") {
+ ConvertResizeBilinearOperator(node, model);
+ } else if (node.op() == "BatchNormWithGlobalNormalization") {
+ ConvertBatchNormWithGlobalNormalizationOperator(node, model);
+ } else if (node.op() == "FusedBatchNorm") {
+ ConvertFusedBatchNormOperator(node, model);
+ } else if (node.op() == "SpaceToBatchND") {
+ ConvertSpaceToBatchNDOperator(node, model);
+ } else if (node.op() == "BatchToSpaceND") {
+ ConvertBatchToSpaceNDOperator(node, model);
+ } else if (node.op() == "Mean") {
+ ConvertMeanOperator(node, model);
+ } else if (node.op() == "Svdf") {
+ ConvertSvdfOperator(node, model);
+ } else {
+ ConvertUnsupportedOperator(node, model);
+ }
+ }
+
+ StripCaretFromArrayNames(model);
+ AddExtraOutputsFedIntoOtherOps(model);
+ FixNoMissingArray(model);
+ FixNoOrphanedArray(model);
+ FixOperatorOrdering(model);
+ CheckInvariants(*model);
+
+ // if rnn state arrays are constant, make them transient
+ for (const auto& rnn_state : model->flags.rnn_states()) {
+ model->GetArray(rnn_state.state_array()).buffer = nullptr;
+ }
+
+ return std::unique_ptr<Model>(model);
+}
+
+std::unique_ptr<Model> ImportTensorFlowGraphDef(
+ const ModelFlags& model_flags, const string& input_file_contents) {
+ std::unique_ptr<GraphDef> tf_graph(new GraphDef);
+ CHECK(ParseFromStringEitherTextOrBinary(input_file_contents, tf_graph.get()));
+
+ std::unique_ptr<GraphDef> pruned_graph =
+ MaybeReplaceCompositeSubgraph(*tf_graph);
+ if (pruned_graph) {
+ tf_graph = std::move(pruned_graph);
+ }
+ return ImportTensorFlowGraphDef(model_flags, *tf_graph);
+}
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.h b/tensorflow/contrib/lite/toco/import_tensorflow.h
new file mode 100644
index 0000000000..d2eb423ca4
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.h
@@ -0,0 +1,34 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_
+
+#include <memory>
+#include <string>
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+
+namespace toco {
+
+std::unique_ptr<Model> ImportTensorFlowGraphDef(
+ const ModelFlags& model_flags, const tensorflow::GraphDef& graph_def);
+
+std::unique_ptr<Model> ImportTensorFlowGraphDef(
+ const ModelFlags& model_flags, const string& input_file_contents);
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
new file mode 100644
index 0000000000..d992f8458f
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -0,0 +1,1372 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_
+
+#include <initializer_list>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+enum class OperatorType {
+ kNone,
+ // General-purpose neural network operators.
+ kAdd,
+ kAveragePool,
+ kBatchNormalization,
+ kConv,
+ kConcatenation,
+ kDepthwiseConv,
+ kDepthToSpace,
+ kSpaceToDepth,
+ kDequantize,
+ kDiv,
+ kFullyConnected,
+ kL2Normalization,
+ kL2Pool,
+ kLstmCell,
+ kLocalResponseNormalization,
+ kLogistic,
+ kMaxPool,
+ kFakeQuant,
+ kMul,
+ kRelu,
+ kRelu1,
+ kRelu6,
+ kSoftmax,
+ kSub,
+ kTanh,
+ kCast,
+ kFloor,
+ kGather,
+ kResizeBilinear,
+ kSpaceToBatchND,
+ kBatchToSpaceND,
+ kPad,
+ kStridedSlice,
+ kSlice,
+ kSqueeze,
+ kMean,
+ // The SVDF Op is a decomposition of a densely connected Op into
+ // low rank filters. For details:
+ // https://research.google.com/pubs/pub43813.html
+ kSvdf,
+ // Special operators used for importing TensorFlow nodes.
+ // The general intent is to have some graph transformation either
+ // drop them or rewrite them as general-purpose operators.
+ kTensorFlowAll,
+ kTensorFlowAssert,
+ kTensorFlowConcat,
+ kTensorFlowConcatV2,
+ kTensorFlowGreater,
+ kTensorFlowGreaterEqual,
+ kTensorFlowIdentity,
+ kTensorFlowLess,
+ kTensorFlowLessEqual,
+ kTensorFlowMax,
+ kTensorFlowMaximum,
+ kTensorFlowMin,
+ kTensorFlowMinimum,
+ kTensorFlowMatMul,
+ kTensorFlowMerge,
+ kTensorFlowReshape,
+ kTensorFlowRsqrt,
+ kTensorFlowShape,
+ kTensorFlowSplit,
+ kTensorFlowSqrt,
+ kTensorFlowSquare,
+ kTensorFlowSum,
+ kTensorFlowSwitch,
+ kTensorFlowTile,
+ // An unsupported TF operation. It's only needed to be able to represent TF
+ // graph internally and is expected to be dropped by graph transformations.
+ kTensorFlowUnsupported,
+ // Finally, TensorFlow uses different conventions for axes ordering,
+ // see AxesOrder, and this cannot always be resolved at the time of importing
+ // nodes, as TensorFlow parameters may be constant-expression subgraphs
+ // instead of being given as plain constant arrays. So we need to insert
+ // special nodes in the graph to shuffle axes.
+ kReorderAxes,
+};
+
+// Helper to deal with TensorFlow arrays using a different ordering of
+// dimensions
+// ("axes") than our own.
+// TODO(benoitjacob): Ultimately, we shouldn't have any "ordering" of axes,
+// we should have associative arrays mapping symbolic axes identifiers (like
+// "output_depth") to dimensions. We would then not need this anymore.
+enum class AxesOrder {
+ kOneAxis, // one-dimensional array, one unique axis.
+ kCR, // column-major matrix storage order. Our standard.
+ kRC, // row-major matrix storage order. TensorFlow default.
+ kOHWI, // Our standard for conv weights
+ kHWIO, // TensorFlow conv weights
+ k1HWO, // Our standard for DepthwiseConv weights
+ kHWIM, // TensorFlow DepthwiseConv weights
+ kNHWC, // TensorFlow activations
+};
+
+// The type of the scalars in an array.
+// Note that that does not by itself tell whether the values in the array are
+// real (are literally interpreted as real numbers) or quantized (only acquire
+// a meaning as real numbers in conjuction with QuantizationParams).
+//
+// In practice though:
+// float values are always real
+// uint8 values are always quantized
+// int32 values are either real or quantized (depending on whether
+// QuantizationParams are present).
+// other types are unused at the moment.
+//
+// kNone means that we don't know the data type yet, or that we don't care
+// because we'll be dropping the array anyway (e.g. some exotic array types
+// may be involved only in debug-only subgraphs that we may not be interested
+// in actually supporting).
+enum class ArrayDataType { kNone, kBool, kFloat, kUint8, kInt32, kInt64 };
+
+// Compile-time logic to map ArrayDataType to the corresponding C++ scalar type
+template <ArrayDataType A>
+struct DataTypeImpl {};
+template <>
+struct DataTypeImpl<ArrayDataType::kNone> {
+ typedef int Type;
+};
+template <>
+struct DataTypeImpl<ArrayDataType::kBool> {
+ typedef bool Type;
+};
+template <>
+struct DataTypeImpl<ArrayDataType::kFloat> {
+ typedef float Type;
+};
+template <>
+struct DataTypeImpl<ArrayDataType::kUint8> {
+ typedef uint8 Type;
+};
+template <>
+struct DataTypeImpl<ArrayDataType::kInt32> {
+ typedef int32 Type;
+};
+template <>
+struct DataTypeImpl<ArrayDataType::kInt64> {
+ typedef int64 Type;
+};
+
+template <ArrayDataType A>
+using DataType = typename DataTypeImpl<A>::Type;
+
+// Base class for type-specific buffer types.
+struct GenericBuffer {
+ // Non-default-constructible: only ArrayDataType-specific subclass
+ // objects may be constructed.
+ GenericBuffer() = delete;
+ // Non-copyable-or-movable: we should only store pointers-to-Buffer
+ // in containers, not Operators themselves, so there should be no
+ // copy or move.
+ GenericBuffer(const GenericBuffer&) = delete;
+ GenericBuffer(const GenericBuffer&&) = delete;
+
+ // We need a virtual destructor so we can store pointers-to-Buffer
+ // in containers and have the containers call the right subclass destructor.
+ virtual ~GenericBuffer() {}
+
+ const ArrayDataType type;
+
+ protected:
+ // Constructor used by subclasses for specific ArrayDataType's.
+ explicit GenericBuffer(ArrayDataType t) : type(t) {}
+};
+
+// Type-specific buffer, containing type-specific storage.
+template <ArrayDataType A>
+struct Buffer : GenericBuffer {
+ Buffer() : GenericBuffer(A) {}
+
+ std::vector<DataType<A>> data;
+};
+
+// Base class for all operator classes.
+struct Operator {
+ // Non-default-constructible: only OperatorType-specific subclass
+ // objects may be constructed.
+ Operator() = delete;
+ // Non-copyable-or-movable: we should only store pointers-to-Operator
+ // in containers, not Operators themselves, so there should be no
+ // copy or move.
+ Operator(const Operator&) = delete;
+ Operator(const Operator&&) = delete;
+
+ // We need a virtual destructor so we can store pointers-to-Operator
+ // in containers and have the containers call the right subclass destructor.
+ virtual ~Operator() {}
+
+ // The specific type of operator. Corresponds 1:1 to subclasses.
+ const OperatorType type;
+
+ // The activation function that may be fused into this operator,
+ // or None if no activation function is fused.
+ FusedActivationFunctionType fused_activation_function;
+
+ // Input arrays: either activation arrays or constant array parameters.
+ // We refer to them by their name, not by their address; the mapping of
+ // names to addresses is given by the Model, which owns both Operator's and
+ // Array's. Thus, an Operator on its own doesn't contain much information,
+ // it is meant to be used in conjunction with the Model that owns it.
+ std::vector<string> inputs;
+
+ // Output activation arrays. Same comments as for inputs apply here too.
+ std::vector<string> outputs;
+
+ // If true, the array has more outputs than are listed in the 'outputs'
+ // member. These need to be resolved by some graph transformation.
+ // This flag is only here to indicate that an operator should not be
+ // discarded as unused, even if from its 'outputs' member alone it
+ // looks unused.
+ bool unresolved_outputs = false;
+
+ protected:
+ // Constructor used by subclasses for specific OperatorType's.
+ explicit Operator(OperatorType t)
+ : type(t),
+ fused_activation_function(FusedActivationFunctionType::kNone) {}
+};
+
+// Padding types for Conv-like operators. This is how padding is typically
+// specified in model files. But for inference, we will need to resolve this
+// to a FixedPadding, see below.
+enum class PaddingType { kNone, kSame, kValid };
+
+// Padding as resolved for a specific layer shape, as needed for inference.
+// For a given layer shape, a given padding type will resolve to a choice of
+// a number of padding rows and columns, which we call the padding height and
+// width respectively.
+struct FixedPadding {
+ int width = 0;
+ int height = 0;
+};
+
+// "Universal" padding struct containing both a generic PaddingType (as
+// represented in a model file), and a FixedPadding (as needed for inference).
+// The latter is resolved during the PropagateFixedSizes pass.
+struct Padding {
+ FixedPadding& GetOrCreateFixedPadding() {
+ if (!fixed) {
+ FixedPadding* ptr = new FixedPadding;
+ fixed = std::unique_ptr<FixedPadding>(ptr);
+ }
+ return *fixed;
+ }
+
+ Padding() : type(PaddingType::kNone) {}
+ PaddingType type;
+ std::unique_ptr<FixedPadding> fixed;
+};
+
+// "Convolutional" layer, as represented in model files.
+//
+// Inputs:
+// inputs[0]: required: the input activations array
+// inputs[1]: required: the Conv weights
+// inputs[2]: optional: the bias vector, specifying the biases for each output
+// channel.
+//
+// Outputs:
+// outputs[0]: required: the output activations array
+// outputs[1]: optional: the intermediate array of im2col-replicated input
+// activations. Present when targeting implementations
+// of Conv layers as Im2col+GEMM.
+//
+// TensorFlow equivalent: Conv2D
+struct ConvOperator : Operator {
+ ConvOperator() : Operator(OperatorType::kConv) {}
+ Padding padding;
+ int stride_width = 0;
+ int stride_height = 0;
+};
+
+// Depthwise-separable convolution operator.
+//
+// Inputs:
+// inputs[0]: required: the input activations array
+// inputs[1]: required: the DepthwiseConv weights
+// inputs[2]: optional: the bias vector, specifying the biases for each output
+// channel.
+//
+// TensorFlow equivalent: DepthwiseConv2dNative
+struct DepthwiseConvOperator : Operator {
+ DepthwiseConvOperator() : Operator(OperatorType::kDepthwiseConv) {}
+ Padding padding;
+ int stride_height = 0;
+ int stride_width = 0;
+ int depth_multiplier = 0;
+};
+
+// Depth-to-space transform operator.
+//
+// Inputs:
+// inputs[0]: required: the input activations array
+//
+// TensorFlow equivalent: DepthToSpace
+struct DepthToSpaceOperator : Operator {
+ DepthToSpaceOperator() : Operator(OperatorType::kDepthToSpace) {}
+ int block_size = 0;
+};
+
+// Space-to-depth transform operator.
+//
+// Inputs:
+// inputs[0]: required: the input activations array
+//
+// TensorFlow equivalent: SpaceToDepth
+struct SpaceToDepthOperator : Operator {
+ SpaceToDepthOperator() : Operator(OperatorType::kSpaceToDepth) {}
+ int block_size = 0;
+};
+
+// Fully-connected operator.
+//
+// Inputs:
+// inputs[0]: required: the input activations array
+// inputs[1]: required: the FullyConnected weights
+// inputs[2]: optional: the bias vector, specifying the biases for each output
+// channel.
+//
+// TensorFlow equivalent: a pair consisting of a Reshape node reshaping the
+// input activations as a matrix, followed by a MatMul node.
+struct FullyConnectedOperator : Operator {
+ FullyConnectedOperator() : Operator(OperatorType::kFullyConnected) {}
+};
+
+// Dequantization operator, converting a quantized array of integers with
+// quantization parameters specifying how these integers correspond to real
+// numbers
+// (see QuantizationParams) to an output activations array of floating-point
+// values.
+//
+// In floating-point image models, there is typically a Dequantization operator
+// at the very beginning, converting the input image RGB data, consisting of
+// uint8 integer values, to floating-point input activations. That is where
+// image model parameters such as "mean_value" and "std_value" are typically
+// handled.
+//
+// This is the only operator type that converts from quantized to
+// floating-point,
+// and there is at the moment no operator type at all to convert from
+// floating-point
+// to quantized. Every other operator does either float->float or
+// quantized->quantized.
+//
+// Inputs:
+// inputs[0]: required: the input quantized activations array
+//
+// TensorFlow equivalent: Dequantize
+struct DequantizeOperator : Operator {
+ DequantizeOperator() : Operator(OperatorType::kDequantize) {}
+};
+
+// Batch-normalization operator.
+//
+// We only support batch-normalization using pre-learned moments, so this is
+// just
+// computing (input - mean) * multiplier + offset. As such, this can be
+// expressed as a combination of Add and Mul nodes, and indeed this is how
+// we break it down during tooling for the purpose of fusing it into
+// other operators.
+//
+// Inputs:
+// inputs[0]: required: the input activations array
+// inputs[1]: required: the learned mean array
+// inputs[2]: required: the learned multiplier array
+// inputs[3]: required: the learned offset array
+//
+// TensorFlow equivalent: a combination of Add and Mul nodes
+struct BatchNormalizationOperator : Operator {
+ BatchNormalizationOperator()
+ : Operator(OperatorType::kBatchNormalization),
+ global_normalization(false) {}
+ bool global_normalization;
+};
+
+// L2-normalization operator.
+//
+// Inputs:
+// inputs[0]: required: the input activations array
+//
+// TensorFlow equivalent: none. In TensorFlow, L2 normalization is implemented
+// by a sub-graph of operators implementing L2-normalization
+// from lower-level arithmetic nodes; during tooling, we identify such
+// sub-graphs
+// and replace them by L2NormalizationOperator's. See IdentifyL2Normalization.
+struct L2NormalizationOperator : Operator {
+ L2NormalizationOperator() : Operator(OperatorType::kL2Normalization) {}
+};
+
+// LSTM Cell operator.
+//
+// Inputs:
+// inputs[0]: required: the input data array
+// inputs[1]: required: the previous output activations array
+// inputs[2]: required: the learned weights array
+// inputs[3]: required: the learned biases array
+// inputs[4]: required: the previous output state
+// outputs[0]: required: the output activations array
+// outputs[1]: required: the new state array
+//
+// TensorFlow equivalent: none. In TensorFlow, an LSTM is implemented
+// with a sub-graph of lower-level arithmetic nodes; during tooling, we identify
+// such sub-graphs and replace them with LstmCells. See IdentifyLstmCell().
+struct LstmCellOperator : Operator {
+ enum Inputs {
+ DATA_INPUT = 0,
+ PREV_ACTIV_INPUT = 1,
+ WEIGHTS_INPUT = 2,
+ BIASES_INPUT = 3,
+ PREV_STATE_INPUT = 4,
+ NUM_INPUTS = 5
+ };
+ enum Outputs {
+ ACTIV_OUTPUT = 0,
+ STATE_OUTPUT = 1,
+ CONCAT_TEMP = 2,
+ ACTIV_TEMP = 3,
+ NUM_OUTPUTS = 4
+ };
+ LstmCellOperator() : Operator(OperatorType::kLstmCell) {}
+};
+
+// Element-wise multiplication operator.
+//
+// Inputs:
+// inputs[0]: required: the left-hand side array
+// inputs[1]: required: the right-hand side array
+//
+// TensorFlow equivalent: Mul
+struct MulOperator : Operator {
+ MulOperator() : Operator(OperatorType::kMul) {}
+};
+
+// Element-wise Relu operator:
+// x -> max(0, x)
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Relu
+struct ReluOperator : Operator {
+ ReluOperator() : Operator(OperatorType::kRelu) {}
+};
+
+// Element-wise Relu1 operator:
+// x -> min(max(x, -1), 1)
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: none. We can construct the operator with Minimum
+// and Maximum operations
+struct Relu1Operator : Operator {
+ Relu1Operator() : Operator(OperatorType::kRelu1) {}
+};
+
+// Element-wise Relu6 operator:
+// x -> max(0, min(6, x))
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Relu6
+struct Relu6Operator : Operator {
+ Relu6Operator() : Operator(OperatorType::kRelu6) {}
+};
+
+// Element-wise Logistic operator:
+// x -> Logistic(x) = 1 / (1 + exp(-x))
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Sigmoid
+struct LogisticOperator : Operator {
+ LogisticOperator() : Operator(OperatorType::kLogistic) {}
+};
+
+// Element-wise Tanh operator:
+// x -> Tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Tanh
+struct TanhOperator : Operator {
+ TanhOperator() : Operator(OperatorType::kTanh) {}
+};
+
+// Element-wise addition operator.
+//
+// Inputs:
+// inputs[0]: required: the left-hand side array
+// inputs[1]: required: the right-hand side array
+//
+// TensorFlow equivalent: Add
+struct AddOperator : Operator {
+ AddOperator() : Operator(OperatorType::kAdd) {}
+};
+
+// Concatenation operator: concatenates its inputs
+// along the concat_dim dimension.
+//
+// Inputs: this operator accepts any number >= 1 of inputs.
+// inputs[i]: the i-th array to concatenate.
+//
+// TensorFlow equivalent: Concat.
+struct ConcatenationOperator : Operator {
+ ConcatenationOperator() : Operator(OperatorType::kConcatenation) {}
+ int concat_dim = 0;
+};
+
+// Reordering dimensions. Used only during tooling to transform graphs from
+// the TensorFlow format.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: none. This is only useful to convert between formats.
+struct ReorderAxesOperator : Operator {
+ ReorderAxesOperator() : Operator(OperatorType::kReorderAxes) {}
+ AxesOrder input_axes_order;
+ AxesOrder output_axes_order;
+};
+
+// Average-pooling operator.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: AveragePool
+struct AveragePoolOperator : Operator {
+ AveragePoolOperator() : Operator(OperatorType::kAveragePool) {}
+ Padding padding;
+ int stride_height = 0;
+ int stride_width = 0;
+ int kheight = 0;
+ int kwidth = 0;
+};
+
+// Local response normalization operator.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: LRN
+struct LocalResponseNormalizationOperator : Operator {
+ LocalResponseNormalizationOperator()
+ : Operator(OperatorType::kLocalResponseNormalization) {}
+
+ int range = 0;
+ float bias = 0.f;
+ float alpha = 0.f;
+ float beta = 0.f;
+};
+
+// Max-pooling operator.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: MaxPool
+struct MaxPoolOperator : Operator {
+ MaxPoolOperator() : Operator(OperatorType::kMaxPool) {}
+ Padding padding;
+ int stride_height = 0;
+ int stride_width = 0;
+ int kheight = 0;
+ int kwidth = 0;
+};
+
+// L2-pooling operator.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: none. Can be shimmed by squaring+avgpool+sqrt.
+struct L2PoolOperator : Operator {
+ L2PoolOperator() : Operator(OperatorType::kL2Pool) {}
+ Padding padding;
+ int stride_height = 0;
+ int stride_width = 0;
+ int kheight = 0;
+ int kwidth = 0;
+};
+
+// The expected [min, max] range of values in a given array.
+// Used for quantization only.
+// This information typically comes from special nodes found in quantized
+// models,
+// see FakeQuantOperator, and is used during quantization to resolve
+// actual quantization parameters (see QuantizationParams).
+struct MinMax {
+ double min = 0.;
+ double max = 0.;
+};
+
+inline bool operator==(const MinMax& m1, const MinMax& m2) {
+ return m1.min == m2.min && m1.max == m2.max;
+}
+
+// Fake-quantization operator. This does two things:
+// - Annotate its input and output arrays with MinMax information,
+// - Arithmetic-wise, this operator rounds incoming activation values
+// to the nearest representable value on the scale of 256
+// values from the min to the max value dictated by its MinMax info.
+//
+// Inputs:
+// inputs[0]: required: the input array
+// inputs[1]: optional: the 'min' value, if it has not yet been resolved
+// to a constant.
+// inputs[2]: optional: the 'max' value, if it has not yet been resolved
+// to a constant.
+//
+// TensorFlow equivalent: FakeQuantWithMinMaxVars, FakeQuantWithMinMaxArgs.
+struct FakeQuantOperator : Operator {
+ FakeQuantOperator() : Operator(OperatorType::kFakeQuant) {}
+ std::unique_ptr<MinMax> minmax;
+};
+
+// Element-wise division operator.
+//
+// Inputs:
+// inputs[0]: required: the left-hand side array
+// inputs[1]: required: the right-hand side array
+//
+// TensorFlow equivalent: Div
+struct DivOperator : Operator {
+ DivOperator() : Operator(OperatorType::kDiv) {}
+};
+
+// Element-wise identity (x->x) operator.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Identity
+struct TensorFlowIdentityOperator : Operator {
+ TensorFlowIdentityOperator() : Operator(OperatorType::kTensorFlowIdentity) {}
+};
+
+// General matrix multiplication operator. We don't want to support general
+// matrix multiplication at inference time, so we resolve it during tooling
+// to more specific operator types, namely, FullyConnected.
+//
+// Inputs:
+// inputs[0]: required: the left-hand side matrix
+// inputs[1]: required: the right-hand side matrix
+//
+// TensorFlow equivalent: MatMul
+struct TensorFlowMatMulOperator : Operator {
+ TensorFlowMatMulOperator() : Operator(OperatorType::kTensorFlowMatMul) {}
+};
+
+// Padding operator. Pads a tensor with zeros.
+//
+// Inputs:
+// inputs[0]: required: the input array
+// inputs[1]: required: the padding array
+//
+// This operation pads a `input` with zeros according to the `paddings` you
+// specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the
+// rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates
+// how many zeros to add before the contents of `input` in that dimension, and
+// `paddings[D, 1]` indicates how many zeros to add after the contents of
+// `input` in that dimension.
+//
+// TensorFlow equivalent: Pad
+struct PadOperator : Operator {
+ PadOperator() : Operator(OperatorType::kPad) {}
+
+ std::vector<int> left_padding;
+ std::vector<int> right_padding;
+};
+
+// Strided slice operator.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: StridedSlice
+struct StridedSliceOperator : Operator {
+ StridedSliceOperator() : Operator(OperatorType::kStridedSlice) {}
+
+ std::vector<int> start_indices;
+ std::vector<int> stop_indices;
+ std::vector<int> strides;
+
+ int begin_mask;
+ int ellipsis_mask;
+ int end_mask;
+ int new_axis_mask;
+ int shrink_axis_mask;
+};
+
+// Reshaping operator, reshaping its input array to a two-dimensional shape
+// (a "matrix"). This is used in the TensorFlow format, in conjunction with
+// MatMul nodes, to implement fully-connected layers.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Reshape --- except that we only support a special case
+// here, where the output shape is a matrix (2D) shape.
+struct TensorFlowReshapeOperator : Operator {
+ TensorFlowReshapeOperator() : Operator(OperatorType::kTensorFlowReshape) {}
+ std::vector<int> shape;
+};
+
+// Removes dimensions of size 1 from the shape of a tensor.
+// https://www.tensorflow.org/api_docs/python/tf/squeeze
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Squeeze
+struct SqueezeOperator : Operator {
+ SqueezeOperator() : Operator(OperatorType::kSqueeze) {}
+
+ std::vector<int> squeeze_dims;
+};
+
+// Element-wise reciprocal-square-root (x^-0.5) operator.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Rsqrt
+struct TensorFlowRsqrtOperator : Operator {
+ TensorFlowRsqrtOperator() : Operator(OperatorType::kTensorFlowRsqrt) {}
+};
+
+// Shape operator. Extracts the shape of the tensor.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// This operation outputs a 1-D integer tensor representing the shape of
+// the input.
+//
+// TensorFlow equivalent: Shape. We currently assume that the output is int32
+// and not int64. The output type could be stored herein.
+struct TensorFlowShapeOperator : Operator {
+ TensorFlowShapeOperator() : Operator(OperatorType::kTensorFlowShape) {}
+};
+
+// Element-wise square-root (x^0.5) operator.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Sqrt
+struct TensorFlowSqrtOperator : Operator {
+ TensorFlowSqrtOperator() : Operator(OperatorType::kTensorFlowSqrt) {}
+};
+
+// Element-wise square (x*x) operator.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Square
+struct TensorFlowSquareOperator : Operator {
+ TensorFlowSquareOperator() : Operator(OperatorType::kTensorFlowSquare) {}
+};
+
+// Element-wise subtraction operator.
+//
+// Inputs:
+// inputs[0]: required: the left-hand side array
+// inputs[1]: required: the right-hand side array
+//
+// TensorFlow equivalent: Sub
+struct SubOperator : Operator {
+ SubOperator() : Operator(OperatorType::kSub) {}
+};
+
+// Global sum reduction: computes the sum of all of entries in the input array.
+// Thus the output is "0-dimensional": it consists of a single scalar value.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Sum --- except that we only support the special case
+// of global reduction across all dimensions.
+struct TensorFlowSumOperator : Operator {
+ TensorFlowSumOperator() : Operator(OperatorType::kTensorFlowSum) {}
+};
+
+// TensorFlow Tile equivalent. Refer to TensorFlow documentation for details.
+// Not fully supported, just a placeholder to handle TensorFlow graphs and
+// support graph transformations to other operator types by matching sub-graphs.
+struct TensorFlowTileOperator : Operator {
+ TensorFlowTileOperator() : Operator(OperatorType::kTensorFlowTile) {}
+};
+
+// TensorFlow Slice equivalent. Refer to TensorFlow documentation for details.
+struct SliceOperator : Operator {
+ SliceOperator() : Operator(OperatorType::kSlice) {}
+
+ std::vector<int> begin;
+ std::vector<int> size;
+};
+
+// TensorFlow Split equivalent. Refer to TensorFlow documentation for details.
+// Not fully supported, just a placeholder to handle TensorFlow graphs and
+// support graph transformations to other operator types by matching sub-graphs.
+struct TensorFlowSplitOperator : Operator {
+ TensorFlowSplitOperator() : Operator(OperatorType::kTensorFlowSplit) {}
+ int num_split = 0;
+};
+
+// TensorFlow Concat equivalent. Refer to TensorFlow documentation for details.
+// Not fully supported, just a placeholder to handle TensorFlow graphs and
+// support graph transformations to other operator types by matching sub-graphs.
+// Concretely, once the concat dim becomes known, if it is the depth
+// dimension then we can change this op into a DepthConcatenation op.
+// Otherwise, we hope for some other graph transformation to drop this node.
+struct TensorFlowConcatOperator : Operator {
+ TensorFlowConcatOperator() : Operator(OperatorType::kTensorFlowConcat) {}
+};
+
+// TensorFlow ConcatV2 equivalent. Refer to TensorFlow documentation for
+// details.
+// Not fully supported, just a placeholder to handle TensorFlow graphs and
+// support graph transformations to other operator types by matching sub-graphs.
+// Concretely, once the concat dim becomes known, if it is the depth
+// dimension then we can change this op into a DepthConcatenation op.
+// Otherwise, we hope for some other graph transformation to drop this node.
+struct TensorFlowConcatV2Operator : Operator {
+ TensorFlowConcatV2Operator() : Operator(OperatorType::kTensorFlowConcatV2) {}
+};
+
+// TensorFlow Merge equivalent. Refer to TensorFlow documentation for details.
+//
+// Inputs: this operator accepts any number >= 1 of inputs.
+// inputs[i]: the i-th array to merge.
+//
+// It is expected that graph transformations will drop all but exactly one
+// of the inputs, at which point the Merge node will be equivalent to an
+// Identity node forwarding the remaining input.
+//
+// Note: We do not currently support runtime control flow: we only support
+// control flow that can be resolved at tooling time (independently of input
+// activations).
+struct TensorFlowMergeOperator : Operator {
+ TensorFlowMergeOperator() : Operator(OperatorType::kTensorFlowMerge) {}
+};
+
+// TensorFlow Switch equivalent. Refer to TensorFlow documentation for details.
+//
+// Inputs:
+// inputs[0]: required: the input array
+// inputs[1]: required: the boolean predicate, given as an array of size 1
+// and of type kBool, will determine which output gets selected.
+//
+// Outputs: a TensorFlow Switch node always has exactly two outputs. Depending
+// on the boolean value that the input predicate resolves to (see note below),
+// one or the other of the outputs will be 'selected': the input array will be
+// forwarded to the 'selected output' as if by a Identity node, while the other
+// output will be discarded, and any graph edge connecting that discarded output
+// will be dropped. The rule for selecting outputs is as follows:
+// outputs[0] will be selected if the input predicate resolves to 'true'.
+// outputs[1] will be selected if the input predicate resolves to 'false'.
+//
+// Note: We do not currently support runtime control flow: we only support
+// control flow that can be resolved at tooling time (independently of input
+// activations).
+struct TensorFlowSwitchOperator : Operator {
+ TensorFlowSwitchOperator() : Operator(OperatorType::kTensorFlowSwitch) {}
+};
+
+// TensorFlow All equivalent. Refer to TensorFlow documentation for details.
+// Not fully supported, just a placeholder to handle TensorFlow graphs and
+// support graph transformations to other operator types by matching sub-graphs.
+// Typically, this is only used as an input to an Assert node, so can be
+// removed as an unused node as we drop Assert nodes.
+struct TensorFlowAllOperator : Operator {
+ TensorFlowAllOperator() : Operator(OperatorType::kTensorFlowAll) {}
+};
+
+// TensorFlow Assert equivalent. Refer to TensorFlow documentation for details.
+// Not fully supported, just a placeholder to handle TensorFlow graphs and
+// support graph transformations to other operator types by matching sub-graphs.
+// Typically, we just drop Assert nodes.
+struct TensorFlowAssertOperator : Operator {
+ TensorFlowAssertOperator() : Operator(OperatorType::kTensorFlowAssert) {}
+};
+
+// TensorFlow Less equivalent. Refer to TensorFlow documentation for details.
+// Not fully supported, just a placeholder to handle TensorFlow graphs and
+// support graph transformations to other operator types by matching sub-graphs.
+// Typically, this is only used as an input to an Assert node, so can be
+// removed as an unused node as we drop Assert nodes.
+struct TensorFlowLessOperator : Operator {
+ TensorFlowLessOperator() : Operator(OperatorType::kTensorFlowLess) {}
+};
+
+// TensorFlow LessEqual equivalent. Refer to TensorFlow documentation for
+// details.
+// Not fully supported, just a placeholder to handle TensorFlow graphs and
+// support graph transformations to other operator types by matching sub-graphs.
+// Typically, this is only used as an input to an Assert node, so can be
+// removed as an unused node as we drop Assert nodes.
+struct TensorFlowLessEqualOperator : Operator {
+ TensorFlowLessEqualOperator()
+ : Operator(OperatorType::kTensorFlowLessEqual) {}
+};
+
+// TensorFlow Less equivalent. Refer to TensorFlow documentation for details.
+// Not fully supported, just a placeholder to handle TensorFlow graphs and
+// support graph transformations to other operator types by matching sub-graphs.
+// Typically, this is only used as an input to an Assert node, so can be
+// removed as an unused node as we drop Assert nodes.
+struct TensorFlowGreaterOperator : Operator {
+ TensorFlowGreaterOperator() : Operator(OperatorType::kTensorFlowGreater) {}
+};
+
+// TensorFlow GreaterEqual equivalent. Refer to TensorFlow documentation for
+// details.
+// Not fully supported, just a placeholder to handle TensorFlow graphs and
+// support graph transformations to other operator types by matching sub-graphs.
+// Typically, this is only used as an input to an Assert node, so can be
+// removed as an unused node as we drop Assert nodes.
+struct TensorFlowGreaterEqualOperator : Operator {
+ TensorFlowGreaterEqualOperator()
+ : Operator(OperatorType::kTensorFlowGreaterEqual) {}
+};
+
+// Global max reduction: computes the max of all of entries in the input array.
+// Thus the output is "0-dimensional": it consists of a single scalar value.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Max --- except that we only support the special case
+// of global reduction across all dimensions.
+struct TensorFlowMaxOperator : Operator {
+ TensorFlowMaxOperator() : Operator(OperatorType::kTensorFlowMax) {}
+};
+
+// Global min reduction: computes the min of all of entries in the input array.
+// Thus the output is "0-dimensional": it consists of a single scalar value.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Min --- except that we only support the special case
+// of global reduction across all dimensions.
+struct TensorFlowMinOperator : Operator {
+ TensorFlowMinOperator() : Operator(OperatorType::kTensorFlowMin) {}
+};
+
+// Element-wise maximum operator. Currently it only supports scalar as
+// the second operand.
+//
+// Inputs:
+// inputs[0]: required: the left-hand side array
+// inputs[1]: required: the right-hand side array
+//
+// TensorFlow equivalent: Maximum
+struct TensorFlowMaximumOperator : Operator {
+ TensorFlowMaximumOperator() : Operator(OperatorType::kTensorFlowMaximum) {}
+};
+
+// Element-wise minimum operator. Currently it only supports scalar as
+// the second operand.
+//
+// Inputs:
+// inputs[0]: required: the left-hand side array
+// inputs[1]: required: the right-hand side array
+//
+// TensorFlow equivalent: Minimum
+struct TensorFlowMinimumOperator : Operator {
+ TensorFlowMinimumOperator() : Operator(OperatorType::kTensorFlowMinimum) {}
+};
+
+// General TF operation, unsupported by tf.mini. Expected to be dropped by
+// graph transformations.
+struct TensorFlowUnsupportedOperator : Operator {
+ TensorFlowUnsupportedOperator()
+ : Operator(OperatorType::kTensorFlowUnsupported) {}
+
+ // The original TF operation type. Used for diagnostic purposes.
+ string tensorflow_op;
+ // A serialized tensorflow::NodeDef string.
+ string tensorflow_node_def;
+ // A boolean indicating if the unsupported op should be treated as quantized.
+ bool quantized = false;
+ // Output data types
+ std::vector<ArrayDataType> output_data_types;
+};
+
+// Softmax activation function.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Softmax
+struct SoftmaxOperator : Operator {
+ SoftmaxOperator() : Operator(OperatorType::kSoftmax) {}
+ float beta = 0.f;
+};
+
+// Cast operator.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Cast
+struct CastOperator : Operator {
+ CastOperator() : Operator(OperatorType::kCast) {}
+ ArrayDataType src_data_type = ArrayDataType::kNone;
+ ArrayDataType dst_data_type = ArrayDataType::kNone;
+};
+
+// Floor operator.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Floor
+struct FloorOperator : Operator {
+ FloorOperator() : Operator(OperatorType::kFloor) {}
+};
+
+// Gather operator. It gathers slices from params according to indices.
+// Only 1-D indices are supported at the moment.
+//
+// Inputs:
+// inputs[0]: required: the params array
+// inputs[1]: required: the indices to gather
+//
+// TensorFlow equivalent: Gather
+struct GatherOperator : Operator {
+ GatherOperator() : Operator(OperatorType::kGather) {}
+ int input_rank;
+};
+
+// ResizeBilinear operator. It resizes input images with bilinear interpolation.
+// It does not support align_corners at the moment.
+//
+// Inputs:
+// inputs[0]: required: the input array
+// inputs[1]: required: the new image size
+//
+// TensorFlow equivalent: ResizeBilinear
+struct ResizeBilinearOperator : Operator {
+ ResizeBilinearOperator() : Operator(OperatorType::kResizeBilinear) {}
+};
+
+// SpaceToBatchND operator. It divides spatial dimensions into a grid of
+// blocks and interleaves these blocks with the batch dimension. Currently,
+// only 2-d blocks are supported.
+//
+// Inputs:
+// inputs[0]: required: the input array
+// inputs[1]: required: the block shape
+// inputs[2]: required: the paddings
+//
+// TensorFlow equivalent: SpaceToBatchND
+struct SpaceToBatchNDOperator : Operator {
+ SpaceToBatchNDOperator() : Operator(OperatorType::kSpaceToBatchND) {}
+};
+
+// BatchToSpaceND operator. Rearranges data from batch into blocks of
+// spatial data. Currently, only 2-d blocks are supported. Cropping is not
+// supported, either, and the crops array should be all zero.
+//
+// Inputs:
+// inputs[0]: required: the input array
+// inputs[1]: required: the block shape
+// inputs[2]: required: the crops
+//
+// TensorFlow equivalent: BatchToSpaceND
+struct BatchToSpaceNDOperator : Operator {
+ BatchToSpaceNDOperator() : Operator(OperatorType::kBatchToSpaceND) {}
+};
+
+// Mean operator.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Mean
+struct MeanOperator : Operator {
+ MeanOperator() : Operator(OperatorType::kMean) {}
+
+ std::vector<int> reduction_indices;
+};
+
+// Svdf operator:
+//
+// Inputs:
+// inputs[0]: required: the input array
+// inputs[1]: required: weights_feature
+// inputs[2]: required: weights_time
+// inputs[3]: optional: bias
+struct SvdfOperator : Operator {
+ SvdfOperator() : Operator(OperatorType::kSvdf) {}
+ int rank;
+};
+
+// Alloc's are used for transient arrays only. An Alloc specifies which interval
+// of the "transient_data" workspace buffer passed to inference functions, is to
+// be used for the transient array at hand. The 'start' and 'end' values are
+// offsets from the start of the workspace buffer, expressed in bytes.
+struct Alloc {
+ int start = 0;
+ int end = 0;
+};
+
+inline bool operator<(const Alloc& a, const Alloc& b) {
+ return a.start < b.start;
+}
+
+// Quantization parameters, determining the mapping of quantized values
+// to real values (i.e. determining how quantized values are mathematically
+// interpreted).
+//
+// The correspondence is as follows:
+//
+// real_value = scale * (quantized_value - zero_point);
+//
+// In other words, zero_point designates which quantized value corresponds to
+// the real 0 value, and scale designates the difference between the real values
+// corresponding to consecutive quantized values differing by 1.
+struct QuantizationParams {
+ int32 zero_point = 0;
+ double scale = 0.;
+};
+
+class Shape {
+ public:
+ // For Shape, we stick to half-way encapsulation for now:
+ // we hide the raw dims_ member, but expose it raw by accessors
+ // because from some brainstorming, it's not at all easy to
+ // anticipate which flavor of more hermetic encapsulation would
+ // actually buy us future-proof-ness without being needlessly
+ // cumbersome.
+ Shape() {}
+ Shape(std::initializer_list<int> dim_list) : dims_(dim_list) {}
+
+ void ReplaceDims(std::initializer_list<int> dim_list) {
+ dims_ = std::vector<int>(dim_list);
+ }
+
+ const std::vector<int>& dims() const { return dims_; }
+ std::vector<int>* mutable_dims() { return &dims_; }
+ const int dimensions_count() const { return dims_.size(); }
+
+ // We still have that one convenience accessor to avoid
+ // the awkward double bracket issue: shape.dims()[i].
+ int dims(int i) const { return dims_[i]; }
+
+ bool operator==(const Shape& comp) const {
+ return (this->dims_ == comp.dims());
+ }
+
+ bool operator!=(const Shape& comp) const { return !((*this) == comp); }
+
+ private:
+ std::vector<int> dims_;
+};
+
+// Array represents an array (either a constant parameter array or an
+// activations array) in a Model.
+struct Array {
+ template <ArrayDataType A>
+ const Buffer<A>& GetBuffer() const {
+ DCHECK(buffer);
+ DCHECK(buffer->type == A);
+ return *static_cast<const Buffer<A>*>(buffer.get());
+ }
+ template <ArrayDataType A>
+ Buffer<A>& GetMutableBuffer() {
+ if (!buffer) {
+ Buffer<A>* ptr = new Buffer<A>;
+ buffer = std::unique_ptr<GenericBuffer>(ptr);
+ }
+ DCHECK(buffer);
+ DCHECK(buffer->type == A);
+ return *static_cast<Buffer<A>*>(buffer.get());
+ }
+ Alloc& GetOrCreateAlloc() {
+ if (!alloc) {
+ alloc = std::unique_ptr<Alloc>(new Alloc);
+ }
+ return *alloc;
+ }
+ MinMax& GetOrCreateMinMax() {
+ if (!minmax) {
+ minmax = std::unique_ptr<MinMax>(new MinMax);
+ }
+ return *minmax;
+ }
+ MinMax& GetMinMax() const {
+ DCHECK(minmax);
+ return *minmax;
+ }
+ QuantizationParams& GetOrCreateQuantizationParams() {
+ if (!quantization_params) {
+ quantization_params =
+ std::unique_ptr<QuantizationParams>(new QuantizationParams);
+ }
+ return *quantization_params;
+ }
+ QuantizationParams& GetQuantizationParams() const {
+ DCHECK(quantization_params);
+ return *quantization_params;
+ }
+
+ // The data type of the actual elements of this array, that is:
+ // - If there is a buffer (see 'buffer' member), it must be of the same
+ // type.
+ // - If there is no buffer, meaning that this is a runtime (i.e. activations)
+ // array, then this specifies the type of elements that there will be
+ // at runtime.
+ //
+ // Note that this only specifies the storage type of elements; this does
+ // not specify whether these are to be treated as 'real' or 'quantized'
+ // values.
+ // That is decided by whether the 'quantization_params' member is null.
+ ArrayDataType data_type = ArrayDataType::kNone;
+ // The final value that data_type should have at the end of graph
+ // transformations
+ ArrayDataType final_data_type = ArrayDataType::kNone;
+ // The dimensions of this array --- this specifies both sizes and strides
+ // (the storage layout).
+ //
+ // Issues with shape handling that remain include:
+ // - No way to distinguish between 0-dimensional dims and missing dims.
+ // - No way to describe dims that may be runtime-variable.
+ // - Addressing of dims by integer index differs in different graph formats
+ // (TensorFlow vs. other frameworks vs. what we have informally grown
+ // within toco).
+ // This is currently quite messy; see ReorderAxesOperator which is how we
+ // bridge some of these discrepancies at the moment. This is overdue for
+ // a redesign; I'm thinking that it would be nice to have more flexible
+ // dims that allow mapping 1:1, cleanly, dims as they are in various
+ // formats,
+ // then explicitly convert between different conventions.
+
+ // Proto-style accessors
+ bool has_shape() const { return array_shape != nullptr; }
+ const Shape& shape() const {
+ CHECK(has_shape());
+ return *array_shape;
+ }
+ Shape* mutable_shape() {
+ if (!array_shape) {
+ array_shape.reset(new Shape);
+ }
+ return array_shape.get();
+ }
+ void copy_shape(const Shape& src_shape) { *mutable_shape() = src_shape; }
+ void clear_shape() { array_shape = nullptr; }
+
+ // The constant buffer backing this array. This is non-null if and only if
+ // this is a constant parameter array. Conversely, this is null for
+ // activations arrays.
+ //
+ // Note that this buffer is pure storage. In the case of quantized values,
+ // it only stores the quantized values, it does not know by itself about the
+ // quantization parameters necessary to interprete these values, that is
+ // in the separate 'quantization_params' field. In fact, this 'buffer' field
+ // does no even know whether values are quantized. It only has a data_type,
+ // which must equal the 'data_type' member here, and which only describes
+ // the storage type of element, does not tell whether they are quantized i.e.
+ // whether they are to be interpreted with quantization_params.
+ std::unique_ptr<GenericBuffer> buffer;
+ // Only for activation arrays (i.e. when 'buffer' is null).
+ // Only for code generation.
+ //
+ // Describes the allocation of this array within the workspace buffer
+ // allocated
+ // for all transient arrays.
+ std::unique_ptr<Alloc> alloc;
+ // Describes the [min, max] range of values
+ // to be assumed when determining quantization_params.
+ //
+ // Only used for quantization. In fact, only used for determining
+ // quantization_params.
+ //
+ // Used for both constant arrays (those having a 'buffer') and non-constant
+ // arrays (activations). Indeed, it is important to use the same min-max range
+ // as was used during training, even if that min-max range is slightly wrong
+ // w.r.t. actual buffer elements. Doing otherwise would defeat the point of
+ // re-training for quantization.
+ std::unique_ptr<MinMax> minmax;
+ // Quantization parameters. The non-null-ness of this pointer is what
+ // defines whether this array is quantized or not.
+ //
+ // If this is non-null, then these quantization parameters are to be used
+ // to assign a meaning as real numbers to the elements of this array.
+ std::unique_ptr<QuantizationParams> quantization_params;
+
+ private:
+ std::unique_ptr<Shape> array_shape;
+};
+
+// Our Model struct, represents an entire model (our "top-level" struct).
+// Owns everything.
+struct Model {
+ Array& GetArray(const string& name) const {
+ DCHECK(arrays.count(name));
+ return *arrays.at(name);
+ }
+ Array& GetOrCreateArray(const string& name) {
+ if (!arrays.count(name)) {
+ Array* ptr = new Array;
+ arrays[name] = std::unique_ptr<Array>(ptr);
+ }
+ Array& result = GetArray(name);
+ return result;
+ }
+
+ // The list of operators. Notice how it's a list of unique_ptr's, implying
+ // that the Model is what owns Operator's and keeps them alive.
+ std::vector<std::unique_ptr<Operator>> operators;
+ // The associative array mapping names to Array's.
+ // Notice how it's a container of unique_ptr's, implying
+ // that the Model is what owns Array's and keeps them alive.
+ // The Operator's refer to these Array's by their name strings, not by their
+ // addresses. See Operator::inputs, Operator::outputs.
+ std::unordered_map<string, std::unique_ptr<Array>> arrays;
+ // Generic flags, a place where we combine information passed to us via
+ // command-line parameters (e.g. --input_width=N) with information that
+ // we may or may not find in the input model file.
+ ModelFlags flags;
+ // For code-generation only: required size of the transient_data buffer
+ std::size_t transient_data_size = 0;
+ // For code-generation only: required alignment of the transient_data buffer
+ std::size_t transient_data_alignment = 0;
+};
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_
diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
new file mode 100644
index 0000000000..699c95753f
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
@@ -0,0 +1,374 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h"
+
+#include <string>
+#include <vector>
+
+#include "absl/strings/ascii.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/strip.h"
+#include "tensorflow/contrib/lite/toco/args.h"
+#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/command_line_flags.h"
+// "batch" flag only exists internally
+#ifdef PLATFORM_GOOGLE
+#include "base/commandlineflags.h"
+#endif
+
+namespace toco {
+
+bool ParseModelFlagsFromCommandLineFlags(
+ int* argc, char* argv[], string* msg,
+ ParsedModelFlags* parsed_model_flags_ptr) {
+ ParsedModelFlags& parsed_flags = *parsed_model_flags_ptr;
+ using tensorflow::Flag;
+ std::vector<tensorflow::Flag> flags = {
+ Flag("input_array", parsed_flags.input_array.bind(),
+ parsed_flags.input_array.default_value(),
+ "Name of the input array. If not specified, will try to read "
+ "that information from the input file."),
+ Flag("input_arrays", parsed_flags.input_arrays.bind(),
+ parsed_flags.input_arrays.default_value(),
+ "Names of the output arrays, comma-separated. If not specified, "
+ "will try to read that information from the input file."),
+ Flag("output_array", parsed_flags.output_array.bind(),
+ parsed_flags.output_array.default_value(),
+ "Name of the output array, when specifying a unique output array. "
+ "If not specified, will try to read that information from the "
+ "input file."),
+ Flag("output_arrays", parsed_flags.output_arrays.bind(),
+ parsed_flags.output_arrays.default_value(),
+ "Names of the output arrays, comma-separated. "
+ "If not specified, will try to read "
+ "that information from the input file."),
+ Flag("input_shape", parsed_flags.input_shape.bind(),
+ parsed_flags.output_arrays.default_value(),
+ "Input array shape. For many models the shape takes the form "
+ "batch size, input array height, input array width, input array "
+ "depth."),
+ Flag("input_shapes", parsed_flags.input_shapes.bind(),
+ parsed_flags.input_shapes.default_value(),
+ "Shapes corresponding to --input_arrays, colon-separated. For "
+ "many models each shape takes the form batch size, input array "
+ "height, input array width, input array depth."),
+ Flag("mean_value", parsed_flags.mean_value.bind(),
+ parsed_flags.mean_value.default_value(),
+ "mean_value parameter for image models, used to compute input "
+ "activations from input pixel data."),
+ Flag("mean_values", parsed_flags.mean_values.bind(),
+ parsed_flags.mean_values.default_value(),
+ "mean_values parameter for image models, comma-separated list of "
+ "doubles, used to compute input activations from input pixel "
+ "data. Each entry in the list should match an entry in "
+ "--input_arrays."),
+ Flag("std_value", parsed_flags.std_value.bind(),
+ parsed_flags.std_value.default_value(),
+ "std_value parameter for image models, used to compute input "
+ "activations from input pixel data."),
+ Flag("std_values", parsed_flags.std_values.bind(),
+ parsed_flags.std_values.default_value(),
+ "std_value parameter for image models, comma-separated list of "
+ "doubles, used to compute input activations from input pixel "
+ "data. Each entry in the list should match an entry in "
+ "--input_arrays."),
+ Flag("variable_batch", parsed_flags.variable_batch.bind(),
+ parsed_flags.variable_batch.default_value(),
+ "If true, the model accepts an arbitrary batch size. Mutually "
+ "exclusive "
+ "with the 'batch' field: at most one of these two fields can be "
+ "set."),
+ Flag(
+ "drop_control_dependency",
+ parsed_flags.drop_control_dependency.bind(),
+ parsed_flags.drop_control_dependency.default_value(),
+ "If true, ignore control dependency requirements in input TensorFlow "
+ "GraphDef. Otherwise an error will be raised upon control dependency "
+ "inputs."),
+ Flag("rnn_states", parsed_flags.rnn_states.bind(),
+ parsed_flags.rnn_states.default_value(), ""),
+ Flag("model_checks", parsed_flags.model_checks.bind(),
+ parsed_flags.model_checks.default_value(),
+ "A list of model checks to be applied to verify the form of the "
+ "model. Applied after the graph transformations after import."),
+ Flag("graphviz_first_array", parsed_flags.graphviz_first_array.bind(),
+ parsed_flags.graphviz_first_array.default_value(),
+ "If set, defines the start of the sub-graph to be dumped to "
+ "GraphViz."),
+ Flag(
+ "graphviz_last_array", parsed_flags.graphviz_last_array.bind(),
+ parsed_flags.graphviz_last_array.default_value(),
+ "If set, defines the end of the sub-graph to be dumped to GraphViz."),
+ Flag("dump_graphviz", parsed_flags.dump_graphviz.bind(),
+ parsed_flags.dump_graphviz.default_value(),
+ "Dump graphviz during LogDump call. If string is non-empty then "
+ "it defines path to dump, otherwise will skip dumping."),
+ Flag("dump_graphviz_video", parsed_flags.dump_graphviz_video.bind(),
+ parsed_flags.dump_graphviz_video.default_value(),
+ "If true, will dump graphviz at each "
+ "graph transformation, which may be used to generate a video."),
+ };
+ bool asked_for_help =
+ *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
+ if (asked_for_help) {
+ *msg += tensorflow::Flags::Usage(argv[0], flags);
+ return false;
+ } else {
+ if (!tensorflow::Flags::Parse(argc, argv, flags)) return false;
+ }
+ auto& dump_options = *GraphVizDumpOptions::singleton();
+ dump_options.graphviz_first_array = parsed_flags.graphviz_first_array.value();
+ dump_options.graphviz_last_array = parsed_flags.graphviz_last_array.value();
+ dump_options.dump_graphviz_video = parsed_flags.dump_graphviz_video.value();
+ dump_options.dump_graphviz = parsed_flags.dump_graphviz.value();
+
+ return true;
+}
+
+void ReadModelFlagsFromCommandLineFlags(
+ const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags) {
+ toco::port::CheckInitGoogleIsDone("InitGoogle is not done yet");
+
+// "batch" flag only exists internally
+#ifdef PLATFORM_GOOGLE
+ CHECK(!((base::SpecifiedOnCommandLine("batch") &&
+ parsed_model_flags.variable_batch.specified())))
+ << "The --batch and --variable_batch flags are mutually exclusive.";
+#endif
+ CHECK(!(parsed_model_flags.output_array.specified() &&
+ parsed_model_flags.output_arrays.specified()))
+ << "The --output_array and --vs flags are mutually exclusive.";
+
+ if (parsed_model_flags.output_array.specified()) {
+ model_flags->add_output_arrays(parsed_model_flags.output_array.value());
+ }
+
+ if (parsed_model_flags.output_arrays.specified()) {
+ std::vector<string> output_arrays =
+ absl::StrSplit(parsed_model_flags.output_arrays.value(), ',');
+ for (const string& output_array : output_arrays) {
+ model_flags->add_output_arrays(output_array);
+ }
+ }
+
+ const bool uses_single_input_flags =
+ parsed_model_flags.input_array.specified() ||
+ parsed_model_flags.mean_value.specified() ||
+ parsed_model_flags.std_value.specified() ||
+ parsed_model_flags.input_shape.specified();
+
+ const bool uses_multi_input_flags =
+ parsed_model_flags.input_arrays.specified() ||
+ parsed_model_flags.mean_values.specified() ||
+ parsed_model_flags.std_values.specified() ||
+ parsed_model_flags.input_shapes.specified();
+
+ QCHECK(!(uses_single_input_flags && uses_multi_input_flags))
+ << "Use either the singular-form input flags (--input_array, "
+ "--input_shape, --mean_value, --std_value) or the plural form input "
+ "flags (--input_arrays, --input_shapes, --mean_values, --std_values), "
+ "but not both forms within the same command line.";
+
+ if (parsed_model_flags.input_array.specified()) {
+ QCHECK(uses_single_input_flags);
+ model_flags->add_input_arrays()->set_name(
+ parsed_model_flags.input_array.value());
+ }
+ if (parsed_model_flags.input_arrays.specified()) {
+ QCHECK(uses_multi_input_flags);
+ for (const auto& input_array :
+ absl::StrSplit(parsed_model_flags.input_arrays.value(), ',')) {
+ model_flags->add_input_arrays()->set_name(string(input_array));
+ }
+ }
+ if (parsed_model_flags.mean_value.specified()) {
+ QCHECK(uses_single_input_flags);
+ model_flags->mutable_input_arrays(0)->set_mean_value(
+ parsed_model_flags.mean_value.value());
+ }
+ if (parsed_model_flags.mean_values.specified()) {
+ QCHECK(uses_multi_input_flags);
+ std::vector<string> mean_values =
+ absl::StrSplit(parsed_model_flags.mean_values.value(), ',');
+ QCHECK(mean_values.size() == model_flags->input_arrays_size());
+ for (int i = 0; i < mean_values.size(); ++i) {
+ char* last = nullptr;
+ model_flags->mutable_input_arrays(i)->set_mean_value(
+ strtod(mean_values[i].data(), &last));
+ CHECK(last != mean_values[i].data());
+ }
+ }
+ if (parsed_model_flags.std_value.specified()) {
+ QCHECK(uses_single_input_flags);
+ model_flags->mutable_input_arrays(0)->set_std_value(
+ parsed_model_flags.std_value.value());
+ }
+ if (parsed_model_flags.std_values.specified()) {
+ QCHECK(uses_multi_input_flags);
+ std::vector<string> std_values =
+ absl::StrSplit(parsed_model_flags.std_values.value(), ',');
+ QCHECK(std_values.size() == model_flags->input_arrays_size());
+ for (int i = 0; i < std_values.size(); ++i) {
+ char* last = nullptr;
+ model_flags->mutable_input_arrays(i)->set_std_value(
+ strtod(std_values[i].data(), &last));
+ CHECK(last != std_values[i].data());
+ }
+ }
+ if (parsed_model_flags.input_shape.specified()) {
+ QCHECK(uses_single_input_flags);
+ if (model_flags->input_arrays().empty()) {
+ model_flags->add_input_arrays();
+ }
+ auto* shape = model_flags->mutable_input_arrays(0)->mutable_shape();
+ shape->Clear();
+ const IntList& list = parsed_model_flags.input_shape.value();
+ for (auto& dim : list.elements) {
+ shape->Add(dim);
+ }
+ }
+ if (parsed_model_flags.input_shapes.specified()) {
+ QCHECK(uses_multi_input_flags);
+ std::vector<string> input_shapes =
+ absl::StrSplit(parsed_model_flags.input_shapes.value(), ':');
+ QCHECK(input_shapes.size() == model_flags->input_arrays_size());
+ for (int i = 0; i < input_shapes.size(); ++i) {
+ auto* shape = model_flags->mutable_input_arrays(i)->mutable_shape();
+ shape->Clear();
+ if (input_shapes[i].empty()) {
+ // empty i.e. 0-dimensional input shape.
+ // Unfortunately, the current toco::InputArray
+ // proto does not allow to distinguish between a known 0-D shape,
+ // and an unknown shape. Indeed, shape is currently a plain array,
+ // and it being empty means unknown shape. So here, we import a
+ // 0-D shape as a 1-D shape of size.
+ // TODO(benoitjacob): fix toco::InputArray to allow 0-D shape,
+ // probably by making shape an optional message,
+ // encapsulating the array.
+ shape->Add(1);
+ } else {
+ for (const auto& dim_str : absl::StrSplit(input_shapes[i], ',')) {
+ int size;
+ CHECK(absl::SimpleAtoi(dim_str, &size))
+ << "Failed to parse input_shape: " << input_shapes[i];
+ shape->Add(size);
+ }
+ }
+ }
+ }
+
+#define READ_MODEL_FLAG(name) \
+ do { \
+ if (parsed_model_flags.name.specified()) { \
+ model_flags->set_##name(parsed_model_flags.name.value()); \
+ } \
+ } while (false)
+
+ READ_MODEL_FLAG(variable_batch);
+ READ_MODEL_FLAG(drop_control_dependency);
+
+#undef READ_MODEL_FLAG
+
+ for (const auto& element : parsed_model_flags.rnn_states.value().elements) {
+ auto* rnn_state_proto = model_flags->add_rnn_states();
+ for (const auto& kv_pair : element) {
+ const string& key = kv_pair.first;
+ const string& value = kv_pair.second;
+ if (key == "state_array") {
+ rnn_state_proto->set_state_array(value);
+ } else if (key == "back_edge_source_array") {
+ rnn_state_proto->set_back_edge_source_array(value);
+ } else if (key == "size") {
+ int32 size = 0;
+ CHECK(absl::SimpleAtoi(value, &size));
+ CHECK_GT(size, 0);
+ rnn_state_proto->set_size(size);
+ } else if (key == "manually_create") {
+ CHECK_EQ(absl::AsciiStrToLower(value), "true");
+ rnn_state_proto->set_manually_create(true);
+ } else {
+ LOG(FATAL) << "Unknown key '" << key << "' in --rnn_states";
+ }
+ }
+ CHECK(rnn_state_proto->has_state_array() &&
+ rnn_state_proto->has_back_edge_source_array() &&
+ rnn_state_proto->has_size())
+ << "--rnn_states must include state_array, back_edge_source_array and "
+ "size.";
+ }
+
+ for (const auto& element : parsed_model_flags.model_checks.value().elements) {
+ auto* model_check_proto = model_flags->add_model_checks();
+ for (const auto& kv_pair : element) {
+ const string& key = kv_pair.first;
+ const string& value = kv_pair.second;
+ if (key == "count_type") {
+ model_check_proto->set_count_type(value);
+ } else if (key == "count_min") {
+ int32 count = 0;
+ CHECK(absl::SimpleAtoi(value, &count));
+ CHECK_GE(count, -1);
+ model_check_proto->set_count_min(count);
+ } else if (key == "count_max") {
+ int32 count = 0;
+ CHECK(absl::SimpleAtoi(value, &count));
+ CHECK_GE(count, -1);
+ model_check_proto->set_count_max(count);
+ } else {
+ LOG(FATAL) << "Unknown key '" << key << "' in --model_checks";
+ }
+ }
+ }
+}
+
+ParsedModelFlags* UncheckedGlobalParsedModelFlags(bool must_already_exist) {
+ static auto* flags = [must_already_exist]() {
+ if (must_already_exist) {
+ fprintf(stderr, __FILE__
+ ":"
+ "GlobalParsedModelFlags() used without initialization\n");
+ fflush(stderr);
+ abort();
+ }
+ return new toco::ParsedModelFlags;
+ }();
+ return flags;
+}
+
+ParsedModelFlags* GlobalParsedModelFlags() {
+ return UncheckedGlobalParsedModelFlags(true);
+}
+
+void ParseModelFlagsOrDie(int* argc, char* argv[]) {
+ // TODO(aselle): in the future allow Google version to use
+ // flags, and only use this mechanism for open source
+ auto* flags = UncheckedGlobalParsedModelFlags(false);
+ string msg;
+ bool model_success =
+ toco::ParseModelFlagsFromCommandLineFlags(argc, argv, &msg, flags);
+ if (!model_success || !msg.empty()) {
+ // Log in non-standard way since this happens pre InitGoogle.
+ fprintf(stderr, "%s", msg.c_str());
+ fflush(stderr);
+ abort();
+ }
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.h b/tensorflow/contrib/lite/toco/model_cmdline_flags.h
new file mode 100644
index 0000000000..dfa3d3c1ef
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.h
@@ -0,0 +1,43 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/args.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+
+namespace toco {
+// Parse and remove arguments for models (in toco). Returns true if parsing
+// is successful. msg has the usage string if there was an error or
+// "--help" was specified
+bool ParseModelFlagsFromCommandLineFlags(
+ int* argc, char* argv[], string* msg,
+ ParsedModelFlags* parsed_model_flags_ptr);
+// Populate the ModelFlags proto with model data.
+void ReadModelFlagsFromCommandLineFlags(
+ const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags);
+// Parse the global model flags to a static
+void ParseModelFlagsOrDie(int* argc, char* argv[]);
+// Get the global parsed model flags
+ParsedModelFlags* GlobalParsedModelFlags();
+
+} // namespace toco
+
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_
diff --git a/tensorflow/contrib/lite/toco/model_flags.proto b/tensorflow/contrib/lite/toco/model_flags.proto
new file mode 100644
index 0000000000..743e08b16f
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/model_flags.proto
@@ -0,0 +1,119 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+syntax = "proto2";
+
+package toco;
+
+// Next ID to USE: 5.
+message InputArray {
+ // Name of the input arrays, i.e. the arrays from which input activations
+ // will be read.
+ optional string name = 1;
+
+ // Shape of the input. For many applications the dimensions are {batch,
+ // height, width, depth}. Often the batch is left "unspecified" by providing
+ // a value of -1.
+ //
+ // The last dimension is typically called 'depth' or 'channels'. For example,
+ // for an image model taking RGB images as input, this would have the value 3.
+ repeated int32 shape = 2;
+
+ // mean_value and std_value parameters control the interpretation of raw input
+ // activation values (elements of the input array) as real numbers. The
+ // mapping is given by:
+ //
+ // real_value = (raw_input_value - mean_value) / std_value
+ //
+ // In particular, the defaults (mean_value=0, std_value=1) yield
+ // real_value = raw_input_value. Often, non-default values are used in image
+ // models. For example, an image model taking uint8 image channel values as
+ // its raw inputs, in [0, 255] range, may use mean_value=128, std_value=128 to
+ // map them into the interval [-1, 1).
+ //
+ // Note: this matches exactly the meaning of mean_value and std_value in
+ // (TensorFlow via LegacyFedInput).
+ optional float mean_value = 3;
+ optional float std_value = 4 [default = 1.];
+}
+
+// ModelFlags encodes properties of a model that, depending on the file
+// format, may or may not be recorded in the model file. The purpose of
+// representing these properties in ModelFlags is to allow passing them
+// separately from the input model file, for instance as command-line
+// parameters, so that we can offer a single uniform interface that can
+// handle files from different input formats.
+//
+// For each of these properties, and each supported file format, we
+// detail in comments below whether the property exists in the given file
+// format.
+//
+// Obsolete flags that have been removed:
+// optional int32 input_depth = 3;
+// optional int32 input_width = 4;
+// optional int32 input_height = 5;
+// optional int32 batch = 6 [ default = 1];
+// optional float mean_value = 7;
+// optional float std_value = 8 [default = 1.];
+// optional int32 input_dims = 11 [ default = 4];
+// repeated int32 input_shape = 13;
+//
+// Next ID to USE: 16.
+message ModelFlags {
+ // Information about the input arrays, i.e. the arrays from which input
+ // activations will be read.
+ repeated InputArray input_arrays = 1;
+
+ // Name of the output arrays, i.e. the arrays into which output activations
+ // will be written.
+ repeated string output_arrays = 2;
+
+ // If true, the model accepts an arbitrary batch size. Mutually exclusive with
+ // the 'batch' field: at most one of these two fields can be set.
+ optional bool variable_batch = 10;
+
+ message RnnState {
+ optional string state_array = 1;
+ optional string back_edge_source_array = 2;
+ optional int32 size = 3;
+ // TODO(benoitjacob): manually_create is a temporary hack:
+ // due to discrepancies between the current toco dims tracking and
+ // TensorFlow shapes, for some models we need to manually create RNN state
+ // arrays with a specified shape.
+ // Maybe we should actually implement back-edges as operators of their own,
+ // which would remove the need for much special-casing, including here,
+ // we could probably consistently let PropagateFixedSizes handle state
+ // arrays.
+ optional bool manually_create = 4;
+ }
+ repeated RnnState rnn_states = 12;
+
+ // Checks applied to the model, typically after toco's comprehensive
+ // graph transformations.
+ // Next ID to USE: 4.
+ message ModelCheck {
+ // Use the name of a type of operator to check its counts.
+ // Use "Total" for overall operator counts.
+ // Use "Arrays" for overall array counts.
+ optional string count_type = 1 [default = "None"];
+ // A count of zero is a meaningful check, so negative used to mean disable.
+ optional int32 count_min = 2 [default = -1];
+ // If count_max < count_min, then count_min is only allowed value.
+ optional int32 count_max = 3 [default = -1];
+ }
+ repeated ModelCheck model_checks = 14;
+
+ // If true, ignore control dependency requirements in input TensorFlow
+ // GraphDef. Otherwise an error will be raised upon control dependency inputs.
+ optional bool drop_control_dependency = 15;
+}
diff --git a/tensorflow/contrib/lite/toco/python/BUILD b/tensorflow/contrib/lite/toco/python/BUILD
new file mode 100644
index 0000000000..92246a8aed
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/python/BUILD
@@ -0,0 +1,76 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
+cc_library(
+ name = "toco_python_api",
+ srcs = ["toco_python_api.cc"],
+ hdrs = ["toco_python_api.h"],
+ deps = [
+ "//tensorflow/contrib/lite/toco:model_flags_proto_cc",
+ "//tensorflow/contrib/lite/toco:toco_flags_proto_cc",
+ "//tensorflow/contrib/lite/toco:toco_port",
+ "//tensorflow/contrib/lite/toco:toco_tooling",
+ "//tensorflow/core:lib",
+ "//util/python:python_headers",
+ ],
+)
+
+tf_py_wrap_cc(
+ name = "tensorflow_wrap_toco",
+ srcs = ["toco.i"],
+ deps = [
+ ":toco_python_api",
+ "//tensorflow/contrib/lite/toco:model_flags_proto_cc",
+ "//tensorflow/contrib/lite/toco:toco_flags_proto_cc",
+ "//util/python:python_headers",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+py_binary(
+ name = "toco_from_protos",
+ srcs = ["toco_from_protos.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":tensorflow_wrap_toco",
+ "//tensorflow/python:platform",
+ ],
+)
+
+py_binary(
+ name = "toco_wrapper",
+ srcs = ["toco_wrapper.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+tf_py_test(
+ name = "toco_from_protos_test",
+ srcs = ["toco_from_protos_test.py"],
+ additional_deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/contrib/lite/toco:model_flags_proto_py",
+ "//tensorflow/contrib/lite/toco:toco_flags_proto_py",
+ ],
+ data = [
+ ":toco_from_protos",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/toco/python/toco.i b/tensorflow/contrib/lite/toco/python/toco.i
new file mode 100644
index 0000000000..3787cba4a3
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/python/toco.i
@@ -0,0 +1,32 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+%include "std_string.i"
+
+%{
+#include "tensorflow/contrib/lite/toco/python/toco_python_api.h"
+%}
+
+namespace toco {
+
+// Convert a model represented in `input_contents`. `model_flags_proto`
+// describes model parameters. `toco_flags_proto` describes conversion
+// parameters (see relevant .protos for more information). Returns a string
+// representing the contents of the converted model.
+PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
+ PyObject* toco_flags_proto_txt_raw,
+ PyObject* input_contents_txt_raw);
+
+} // namespace toco \ No newline at end of file
diff --git a/tensorflow/contrib/lite/toco/python/toco_from_protos.py b/tensorflow/contrib/lite/toco/python/toco_from_protos.py
new file mode 100644
index 0000000000..c0b032083b
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/python/toco_from_protos.py
@@ -0,0 +1,63 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python console command to invoke TOCO from serialized protos."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+from tensorflow.contrib.lite.toco.python import tensorflow_wrap_toco
+from tensorflow.python.platform import app
+
+FLAGS = None
+
+
+def execute(unused_args):
+ model_str = open(FLAGS.model_proto_file, "rb").read()
+ toco_str = open(FLAGS.toco_proto_file, "rb").read()
+ input_str = open(FLAGS.model_input_file, "rb").read()
+
+ output_str = tensorflow_wrap_toco.TocoConvert(model_str, toco_str, input_str)
+ open(FLAGS.model_output_file, "wb").write(output_str)
+ sys.exit(0)
+
+
+def main():
+ global FLAGS
+ parser = argparse.ArgumentParser(
+ description="Invoke toco using protos as input.")
+ parser.add_argument(
+ "model_proto_file",
+ type=str,
+ help="File containing serialized proto that describes the model.")
+ parser.add_argument(
+ "toco_proto_file",
+ type=str,
+ help="File containing serialized proto describing how TOCO should run.")
+ parser.add_argument(
+ "model_input_file", type=str, help="Input model is read from this file.")
+ parser.add_argument(
+ "model_output_file",
+ type=str,
+ help="Result of applying TOCO conversion is written here.")
+
+ FLAGS, unparsed = parser.parse_known_args()
+
+ app.run(main=execute, argv=[sys.argv[0]] + unparsed)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py b/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py
new file mode 100644
index 0000000000..2a593beeca
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py
@@ -0,0 +1,96 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tempfile
+
+import tensorflow as tf
+from tensorflow.contrib.lite.toco import model_flags_pb2
+from tensorflow.contrib.lite.toco import toco_flags_pb2
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import resource_loader
+
+
+def TensorName(x):
+ """Get the canonical (non foo:0 name)."""
+ return x.name.split(":")[0]
+
+
+class TocoFromProtosTest(googletest.TestCase):
+
+ def _run(self, sess, in_tensor, out_tensor, should_succeed):
+ """Use toco binary to check conversion from graphdef to tflite.
+
+ Args:
+ sess: Active TensorFlow session containing graph.
+ in_tensor: TensorFlow tensor to use as input.
+ out_tensor: TensorFlow tensor to use as output.
+ should_succeed: Whether this is a valid conversion.
+ """
+ # Build all protos and extract graphdef
+ graph_def = sess.graph_def
+ toco_flags = toco_flags_pb2.TocoFlags()
+ toco_flags.input_format = toco_flags_pb2.TENSORFLOW_GRAPHDEF
+ toco_flags.output_format = toco_flags_pb2.TFLITE
+ toco_flags.input_types.append(toco_flags_pb2.FLOAT)
+ toco_flags.inference_type = toco_flags_pb2.FLOAT
+ model_flags = model_flags_pb2.ModelFlags()
+ input_array = model_flags.input_arrays.add()
+ input_array.name = TensorName(in_tensor)
+ input_array.shape.extend(map(int, in_tensor.get_shape()))
+ model_flags.output_arrays.append(TensorName(out_tensor))
+ # Shell out to run toco (in case it crashes)
+ with tempfile.NamedTemporaryFile() as fp_toco, \
+ tempfile.NamedTemporaryFile() as fp_model, \
+ tempfile.NamedTemporaryFile() as fp_input, \
+ tempfile.NamedTemporaryFile() as fp_output:
+ fp_model.write(model_flags.SerializeToString())
+ fp_toco.write(toco_flags.SerializeToString())
+ fp_input.write(graph_def.SerializeToString())
+ fp_model.flush()
+ fp_toco.flush()
+ fp_input.flush()
+ tflite_bin = resource_loader.get_path_to_datafile("toco_from_protos")
+ cmdline = " ".join([
+ tflite_bin, fp_model.name, fp_toco.name, fp_input.name, fp_output.name
+ ])
+ exitcode = os.system(cmdline)
+ if exitcode == 0:
+ stuff = fp_output.read()
+ self.assertEqual(stuff is not None, should_succeed)
+ else:
+ self.assertFalse(should_succeed)
+
+ def test_toco(self):
+ """Run a couple of TensorFlow graphs against TOCO through the python bin."""
+ with tf.Session() as sess:
+ img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
+ val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
+ out = tf.identity(val, name="out")
+ out2 = tf.sin(val, name="out2")
+ # This is a valid mdoel
+ self._run(sess, img, out, True)
+ # This uses an invalid function.
+ # TODO(aselle): Check to make sure a warning is included.
+ self._run(sess, img, out2, True)
+ # This is an identity graph, which doesn't work
+ self._run(sess, img, img, False)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.cc b/tensorflow/contrib/lite/toco/python/toco_python_api.cc
new file mode 100644
index 0000000000..8a5e483f3f
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/python/toco_python_api.cc
@@ -0,0 +1,85 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string>
+#include <vector>
+#include "tensorflow/core/platform/logging.h"
+
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/python/toco_python_api.h"
+#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/toco_tooling.h"
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+
+namespace toco {
+
+#if PY_MAJOR_VERSION >= 3
+#define TOCO_PY_TO_CPPSTRING PyBytes_AsStringAndSize
+#define TOCO_FROM_CPPSTRING_TO_PY PyBytes_FromStringAndSize
+#else
+#define TOCO_PY_TO_CPPSTRING PyString_AsStringAndSize
+#define TOCO_FROM_CPPSTRING_TO_PY PyString_FromStringAndSize
+#endif
+
+// NOTE(aselle): We are using raw PyObject's here because we want to make
+// sure we input and output bytes rather than unicode strings for Python3.
+PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
+ PyObject* toco_flags_proto_txt_raw,
+ PyObject* input_contents_txt_raw) {
+ // Use Python C API to validate and convert arguments. In py3 (bytes),
+ // in py2 (str).
+ auto ConvertArg = [&](PyObject* obj, bool* error) {
+ char* buf;
+ Py_ssize_t len;
+ if (TOCO_PY_TO_CPPSTRING(obj, &buf, &len) == -1) {
+ *error = true;
+ return std::string();
+ } else {
+ *error = false;
+ return std::string(buf, len);
+ }
+ };
+
+ bool error;
+ std::string model_flags_proto_txt =
+ ConvertArg(model_flags_proto_txt_raw, &error);
+ if (error) return nullptr;
+ std::string toco_flags_proto_txt =
+ ConvertArg(toco_flags_proto_txt_raw, &error);
+ if (error) return nullptr;
+ std::string input_contents_txt = ConvertArg(input_contents_txt_raw, &error);
+ if (error) return nullptr;
+
+ // Use toco to produce new outputs
+ toco::ModelFlags model_flags;
+ if (!model_flags.ParseFromString(model_flags_proto_txt)) {
+ LOG(FATAL) << "Model proto failed to parse." << std::endl;
+ }
+ toco::TocoFlags toco_flags;
+ if (!toco_flags.ParseFromString(toco_flags_proto_txt)) {
+ LOG(FATAL) << "Toco proto failed to parse." << std::endl;
+ }
+ std::unique_ptr<toco::Model> model =
+ toco::Import(toco_flags, model_flags, input_contents_txt);
+ toco::Transform(toco_flags, model.get());
+ string output_file_contents_txt;
+ Export(toco_flags, *model, &output_file_contents_txt);
+
+ // Convert arguments back to byte (py3) or str (py2)
+ return TOCO_FROM_CPPSTRING_TO_PY(output_file_contents_txt.data(),
+ output_file_contents_txt.size());
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.h b/tensorflow/contrib/lite/toco/python/toco_python_api.h
new file mode 100644
index 0000000000..dc378353f7
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/python/toco_python_api.h
@@ -0,0 +1,33 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_
+#define _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_
+
+#include <string>
+#include <Python.h>
+
+namespace toco {
+
+// Convert a model represented in `input_contents`. `model_flags_proto`
+// describes model parameters. `toco_flags_proto` describes conversion
+// parameters (see relevant .protos for more information). Returns a string
+// representing the contents of the converted model.
+PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
+ PyObject* toco_flags_proto_txt_raw,
+ PyObject* input_contents_txt_raw);
+
+} // namespace toco
+
+#endif // _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_
diff --git a/tensorflow/contrib/lite/toco/python/toco_wrapper.py b/tensorflow/contrib/lite/toco/python/toco_wrapper.py
new file mode 100644
index 0000000000..e39b5f22c7
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/python/toco_wrapper.py
@@ -0,0 +1,35 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Wrapper for runninmg toco binary embedded in pip site-package.
+
+NOTE: this mainly exists since PIP setup.py cannot install binaries to bin/.
+It can only install Python "console-scripts." This will work as a console
+script. See tools/pip_package/setup.py (search for CONSOLE_SCRIPTS).
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+import tensorflow as tf
+
+
+def main():
+ # Pip installs the binary in aux-bin off of main site-package install.
+ # Just find it and exec, passing all arguments in the process.
+ # TODO(aselle): it is unfortunate to use all of tensorflow to lookup binary.
+ binary = os.path.join(tf.__path__[0], 'aux-bin/toco')
+ os.execvp(binary, sys.argv)
diff --git a/tensorflow/contrib/lite/toco/runtime/common.h b/tensorflow/contrib/lite/toco/runtime/common.h
new file mode 100644
index 0000000000..bd55544f57
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/runtime/common.h
@@ -0,0 +1,26 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_
+
+#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+#ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
+#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+#endif
+#endif
+
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_
diff --git a/tensorflow/contrib/lite/toco/runtime/types.h b/tensorflow/contrib/lite/toco/runtime/types.h
new file mode 100644
index 0000000000..df63b2d59e
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/runtime/types.h
@@ -0,0 +1,32 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_
+
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace toco {
+
+// TODO(ahentz): These are just stopgaps for now, untils we move all
+// the code over to tflite.
+using tflite::Dims;
+using tflite::FusedActivationFunctionType;
+using tflite::RequiredBufferSizeForDims;
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD
new file mode 100644
index 0000000000..0c1a1141fc
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD
@@ -0,0 +1,102 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+)
+
+cc_library(
+ name = "cluster_utils",
+ srcs = [
+ "cluster_utils.cc",
+ ],
+ hdrs = [
+ "cluster_utils.h",
+ ],
+ deps = [
+ "//tensorflow/contrib/lite/toco:toco_port",
+ ],
+)
+
+cc_library(
+ name = "cluster",
+ srcs = [
+ "cluster.cc",
+ ],
+ hdrs = [
+ "cluster.h",
+ ],
+ deps = [
+ ":cluster_utils",
+ "//tensorflow/contrib/lite/toco:model",
+ "//tensorflow/contrib/lite/toco:tooling_util",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_library(
+ name = "resolve_svdf",
+ srcs = [
+ "resolve_svdf.cc",
+ ],
+ hdrs = [
+ "resolve_svdf.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":cluster",
+ ":cluster_utils",
+ "//tensorflow/contrib/lite/toco:model",
+ "//tensorflow/contrib/lite/toco:toco_port",
+ "//tensorflow/contrib/lite/toco:tooling_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+)
+
+tf_cc_test(
+ name = "resolve_svdf_test",
+ srcs = ["resolve_svdf_test.cc"],
+ deps = [
+ ":cluster",
+ ":cluster_utils",
+ ":resolve_cluster",
+ ":resolve_svdf",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "resolve_cluster",
+ srcs = [
+ "resolve_cluster.cc",
+ ],
+ hdrs = [
+ "resolve_cluster.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":cluster",
+ ":cluster_utils",
+ ":resolve_svdf",
+ "//tensorflow/contrib/lite/toco:tooling_util",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.cc b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.cc
new file mode 100644
index 0000000000..98a130ea39
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.cc
@@ -0,0 +1,52 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h"
+
+namespace toco {
+
+void Cluster::SetGraphDefInfo(const tensorflow::GraphDef* graph_def) {
+ graph_def_ = graph_def;
+ for (const tensorflow::NodeDef& node : graph_def_->node()) {
+ if (StrContains(node.name(), name_)) {
+ nodes_.push_back(&node);
+ }
+ }
+}
+
+bool Cluster::FindClusterInputsAndOutputs() {
+ // For every node N in the graph:
+ // If N belongs to this cluster C, then each of N's inputs that are not part
+ // of C are then inputs of C.
+ // If N does not belong to cluster C, then each of N's inputs that belong to C
+ // are then outputs of C.
+ for (const tensorflow::NodeDef& node : graph_def_->node()) {
+ if (StrContains(node.name(), name_)) {
+ for (int i = 0; i < node.input_size(); i++) {
+ if (!StrContains(node.input(i), name_)) {
+ inputs_.push_back(node.input(i));
+ }
+ }
+ } else {
+ for (int i = 0; i < node.input_size(); i++) {
+ if (StrContains(node.input(i), name_)) {
+ outputs_.push_back(node.input(i));
+ }
+ }
+ }
+ }
+ return (!inputs_.empty()) && (!outputs_.empty());
+}
+
+} // end namespace toco
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h
new file mode 100644
index 0000000000..18ff73ac39
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h
@@ -0,0 +1,101 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+
+namespace toco {
+
+// The base class for Cluster. A cluster is group of nodes all related to each
+// other because their name match a given "pattern", which shows they all belong
+// to a composite op supported in TFLite. The nodes in a cluster will be
+// collapsed into a single composite op node plus a series of constant nodes
+// holding the input parameters to that node. The nodes in a cluster are assumed
+// to be using the same device. By changing the "pattern" we can have different
+// subclasses of the base Cluster class.
+class Cluster {
+ public:
+ virtual ~Cluster() {}
+
+ virtual void CreateNodes() = 0;
+
+ // Save the following info from the original GraphDef this cluster is from:
+ // 1- a pointer to the GraphDef
+ // 2- All the nodes in GraphDef which belong to this cluster.
+ void SetGraphDefInfo(const tensorflow::GraphDef* graph_def);
+
+ const string& GetName() const { return name_; }
+
+ const std::vector<std::unique_ptr<tensorflow::NodeDef>>& GetNewNodes() const {
+ return new_nodes_;
+ }
+
+ const std::vector<const tensorflow::NodeDef*>& GetNodes() { return nodes_; }
+
+ void SetName(const string& name) { name_ = name; }
+
+ void SetDevice(const string& device) { device_ = device; }
+
+ // Find the input(s) and output(s) of this Cluster.
+ bool FindClusterInputsAndOutputs();
+
+ protected:
+ string name_;
+ string device_;
+ std::vector<string> inputs_;
+ std::vector<string> outputs_;
+
+ // Used to hold the pointers to nodes which are in this cluster. These nodes
+ // are pointing to the nodes in graph_def_.
+ std::vector<const tensorflow::NodeDef*> nodes_;
+
+ // Used to cache the newly generated nodes: like the nodes created by
+ // collapsing Const nodes, or the nodes which is used to show the composite
+ // op.
+ std::vector<std::unique_ptr<tensorflow::NodeDef>> new_nodes_;
+
+ const tensorflow::GraphDef* graph_def_; /*Not owned*/
+};
+
+// A factory interface for cluster class.
+// It defines a virtual function interface which is responsible for creating
+// a cluster. Each cluster factory is responsible to pack a cluster of nodes
+// into a cluster using a name-based pattern matching approach.
+class ClusterFactoryInterface {
+ public:
+ virtual ~ClusterFactoryInterface() {}
+
+ // Creates a cluster of nodes using a name-based pattern matching approach. It
+ // uses a node as a seed and if its name matches a certain pattern, then it
+ // builds the cluster around that node.
+ virtual std::unique_ptr<Cluster> CreateCluster(
+ const tensorflow::NodeDef& node,
+ const tensorflow::GraphDef& graph_def) const = 0;
+};
+
+} // end namespace toco
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.cc b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.cc
new file mode 100644
index 0000000000..14c3cd6487
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.cc
@@ -0,0 +1,34 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string>
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+namespace toco {
+
+bool StrContains(const string& x, const string& search_pattern) {
+ return x.find(search_pattern) != string::npos;
+}
+
+void Transpose2DTensor(const float* tensor, int row, int col,
+ float* transposed_tensor) {
+ float* result = transposed_tensor;
+ for (int r = 0; r < row; ++r) {
+ for (int c = 0; c < col; ++c) {
+ *(result + c * row) = *tensor++;
+ }
+ ++result;
+ }
+}
+
+} // end namespace toco
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h
new file mode 100644
index 0000000000..a15e480e70
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h
@@ -0,0 +1,33 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTERUTILS_H
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTERUTILS_H
+
+#include <string>
+
+namespace toco {
+
+// Check if string x includes string search_pattern.
+bool StrContains(const string& x, const string& search_pattern);
+
+// Transpose a 2D tensor of size row * col pointed by "tensor" and return the
+// results in "transposed_tensor". "transposed_tensor" must be pre-allocated
+// by the same size as "tensor".
+void Transpose2DTensor(const float* tensor, int row, int col,
+ float* transposed_tensor);
+
+} // end namespace toco
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTERUTILS_H
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.cc b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.cc
new file mode 100644
index 0000000000..fddf6cc836
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.cc
@@ -0,0 +1,151 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h"
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+
+namespace toco {
+
+using tensorflow::GraphDef;
+using tensorflow::NodeDef;
+
+void AddNodeToGraph(const NodeDef& node,
+ const std::vector<string>& cluster_names, GraphDef* graph) {
+ NodeDef* new_node = graph->add_node();
+ new_node->set_op(node.op());
+ new_node->set_name(node.name());
+ new_node->set_device(node.device());
+ // If the inputs are coming from a node which belongs to another cluster, then
+ // those inputs are renamed to the source cluster name. Otherwise the original
+ // input name is used.
+ for (const string& node_input : node.input()) {
+ bool input_from_cluster = false;
+ for (const string& cluster_name : cluster_names) {
+ if (StrContains(node_input, cluster_name) &&
+ !StrContains(node.name(), cluster_name)) {
+ new_node->add_input(cluster_name);
+ input_from_cluster = true;
+ break;
+ }
+ }
+ if (!input_from_cluster) {
+ new_node->add_input(node_input);
+ }
+ }
+ for (const auto& attr : node.attr()) {
+ (*new_node->mutable_attr())[attr.first] = attr.second;
+ }
+}
+
+bool FindCluster(const ClusterFactoryInterface& cluster_factory,
+ const GraphDef& graph_def,
+ std::unordered_map<string, bool>* is_node_in_cluster,
+ std::vector<std::unique_ptr<Cluster>>* clusters) {
+ for (const NodeDef& node : graph_def.node()) {
+ // If the node is not assigned to any cluster, then we check if it belong to
+ // the cluster_factory.
+ bool node_in_cluster = (*is_node_in_cluster)[node.name()];
+ if (!node_in_cluster) {
+ std::unique_ptr<Cluster> cluster =
+ cluster_factory.CreateCluster(node, graph_def);
+ if (cluster) {
+ // Label all the nodes in is_node_in_cluster which are in this cluster
+ // as belonged to this cluster.
+ for (const NodeDef* cluster_node : cluster->GetNodes()) {
+ (*is_node_in_cluster)[cluster_node->name()] = true;
+ }
+ clusters->push_back(std::move(cluster));
+ }
+ }
+ }
+ return (!clusters->empty());
+}
+
+std::unique_ptr<GraphDef> MaybeResolveClusters(
+ const GraphDef& graph_def,
+ const std::vector<ClusterFactoryInterface*>& cluster_factories) {
+ std::unique_ptr<GraphDef> pruned_graph(new GraphDef);
+ // The structure to keep track of which cluster each node is assigned to, and
+ // to initialize them to all un-assigned,
+ std::unordered_map<string, bool> is_node_in_cluster;
+ for (const NodeDef& node : graph_def.node()) {
+ is_node_in_cluster[node.name()] = false;
+ }
+
+ std::vector<string> cluster_names;
+ std::vector<std::unique_ptr<Cluster>> all_clusters;
+ // Find the clusters for all available cluster factories.
+ for (const ClusterFactoryInterface* cluster_factory : cluster_factories) {
+ std::vector<std::unique_ptr<Cluster>> clusters;
+ if (FindCluster(*cluster_factory, graph_def, &is_node_in_cluster,
+ &clusters)) {
+ for (auto itr = clusters.begin(); itr != clusters.end(); ++itr) {
+ cluster_names.push_back((*itr)->GetName());
+ (*itr)->CreateNodes();
+ all_clusters.push_back(std::move(*itr));
+ }
+ }
+ }
+
+ for (const std::unique_ptr<Cluster>& cluster : all_clusters) {
+ for (const std::unique_ptr<tensorflow::NodeDef>& src_node :
+ cluster->GetNewNodes()) {
+ // Add it to the output GraphDef.
+ AddNodeToGraph(*src_node, cluster_names, pruned_graph.get());
+ }
+ }
+
+ // Add any node which is not part of a cluster.
+ for (const NodeDef& node : graph_def.node()) {
+ bool node_in_cluster = is_node_in_cluster[node.name()];
+ if (!node_in_cluster) {
+ AddNodeToGraph(node, cluster_names, pruned_graph.get());
+ }
+ }
+
+ if (pruned_graph->node_size() == 0) {
+ return nullptr;
+ } else {
+ return pruned_graph;
+ }
+}
+
+std::unique_ptr<GraphDef> MaybeReplaceCompositeSubgraph(
+ const GraphDef& tf_graph) {
+ SvdfClusterFactory svdf_cluster_factory;
+
+ std::vector<ClusterFactoryInterface*> cluster_factories;
+ cluster_factories.push_back(&svdf_cluster_factory);
+
+ std::unique_ptr<GraphDef> pruned_graph =
+ MaybeResolveClusters(tf_graph, cluster_factories);
+
+ // Copy function definitions
+ *(pruned_graph->mutable_library()) = tf_graph.library();
+ return pruned_graph;
+}
+
+} // end namespace toco
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h
new file mode 100644
index 0000000000..7d33dd1885
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_CLUSTER_H
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_CLUSTER_H
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+
+namespace toco {
+
+// Given a graph info and a list of cluster classes (cluster_factories), it
+// partitions the graph to clusters, and then collapses each cluster into their
+// corresponding composite ops. It generates a new graph using the newly
+// generated composite ops. Each cluster factory is responsible to recognize a
+// cluster of nodes into a cluster using a name-based pattern matching approach.
+std::unique_ptr<tensorflow::GraphDef> MaybeResolveClusters(
+ const tensorflow::GraphDef& graph_def,
+ const std::vector<ClusterFactoryInterface*>& cluster_factories);
+
+// Adds a node to a given graph. The added node will be a copy of a given source
+// node, except for the inputs. If the inputs are coming from a node which
+// belongs to another cluster, then those inputs are renamed to the source
+// cluster name.
+void AddNodeToGraph(const tensorflow::NodeDef& node,
+ const std::vector<string>& cluster_names,
+ tensorflow::GraphDef* graph);
+
+// Given a graph and a cluster class, it finds all the nodes which belong to a
+// given class factory, encapsulate them inside a cluster of the given type and
+// returns a vector of those clusters. It also labels the nodes in that graph if
+// they belong to the generated clusters.
+bool FindCluster(const ClusterFactoryInterface& cluster_factory,
+ const tensorflow::GraphDef& graph_def,
+ std::unordered_map<string, bool>* is_node_in_cluster,
+ std::vector<std::unique_ptr<Cluster>>* clusters);
+
+// Receives a graph and generates another graph by replacing the cluster of
+// nodes which matches a given composite op. Each composite op is represented
+// using a class factory.
+std::unique_ptr<tensorflow::GraphDef> MaybeReplaceCompositeSubgraph(
+ const tensorflow::GraphDef& tf_graph);
+
+} // end namespace toco
+
+#endif // CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_CLUSTER_H
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.cc b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.cc
new file mode 100644
index 0000000000..d6a099817c
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.cc
@@ -0,0 +1,285 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h"
+
+#include <ctype.h>
+#include <stddef.h>
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "google/protobuf/map.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/platform/logging.h"
+
+using tensorflow::GraphDef;
+using tensorflow::NodeDef;
+
+namespace toco {
+
+namespace {
+
+// Receives a vector of cluster nodes and returns only those which are array
+// partitions (of type 'Const' and have the pattern 'part_<.*>' in their name.
+// Since these nodes are connected to a Concatenate node, it makes sure the
+// axis value input of the Concatenate operator is 0.
+void FilterPartitionedConstNodes(
+ const string& const_pattern,
+ const std::vector<const NodeDef*>& cluster_nodes,
+ std::vector<const NodeDef*>* const_node_parts) {
+ for (const NodeDef* node : cluster_nodes) {
+ string node_name_to_upper = node->name();
+ std::transform(node_name_to_upper.begin(), node_name_to_upper.end(),
+ node_name_to_upper.begin(), ::toupper);
+ if (StrContains(node->name(), const_pattern) && node->op() == "Const") {
+ if (StrContains(node_name_to_upper, "/PART_")) {
+ const_node_parts->push_back(node);
+ } else if (StrContains(node->name(), "AXIS") &&
+ StrContains(node->name(), "CONCAT")) {
+ // For now only supporting Concatenate on Axix 0
+ const auto& value_attr = node->attr().at("value");
+ const tensorflow::TensorProto& tensor = value_attr.tensor();
+ CHECK_EQ(tensor.int_val(0), 0);
+ }
+ }
+ }
+ sort(const_node_parts->begin(), const_node_parts->end(),
+ [](const NodeDef* a, const NodeDef* b) {
+ return (a->name().compare(b->name()) < 0 &&
+ (a->name().size() < b->name().size()));
+ });
+}
+
+} // namespace
+
+// SvdfCluster methods
+
+int SvdfCluster::InferFilterRank() {
+ for (const NodeDef* node : nodes_) {
+ if (StrContains(node->name(), "Reshape/shape")) {
+ const auto& value_attr = node->attr().at("value");
+ const tensorflow::TensorProto& tensor = value_attr.tensor();
+ std::vector<int32> shape_values(
+ tensor.tensor_content().size() / sizeof(int), 0);
+ port::CopyToBuffer(tensor.tensor_content(),
+ reinterpret_cast<char*>(shape_values.data()));
+ CHECK_EQ(shape_values.size(), 3);
+ // shape_value array is arranged as:
+ // [num_units, rank, -1]
+ CHECK_EQ(shape_values[2], -1);
+ return shape_values[1];
+ }
+ }
+ return -1;
+}
+
+void SvdfCluster::CreateNodes() {
+ for (const string& const_pattern : const_node_patterns_) {
+ CreateConstNode(const_pattern);
+ }
+ std::unique_ptr<tensorflow::NodeDef> svdf_node(new NodeDef);
+ svdf_node->set_op("Svdf");
+ svdf_node->set_name(name_);
+ svdf_node->set_device(device_);
+
+ // Add the main input.
+ svdf_node->add_input(inputs_[0]);
+
+ // Add the rest of the inputs to Svdf cell: weights and bias.
+ CHECK(new_nodes_.size() == 3 || new_nodes_.size() == 2);
+ string* weights_feature_input = svdf_node->add_input();
+ string* weights_time_input = svdf_node->add_input();
+ string* bias_input;
+ if (new_nodes_.size() == 3) {
+ bias_input = svdf_node->add_input();
+ }
+ for (const std::unique_ptr<tensorflow::NodeDef>& node : new_nodes_) {
+ const string node_name = node->name();
+ if (StrContains(node_name, "SVDF_weights_feature")) {
+ *weights_feature_input = node_name;
+ } else if (StrContains(node_name, "SVDF_weights_time")) {
+ *weights_time_input = node_name;
+ } else if (StrContains(node_name, "SVDF_bias")) {
+ CHECK(bias_input) << "Bias input cannot be provided when there are only "
+ "two Const input nodes!";
+ *bias_input = node_name;
+ } else {
+ // Unexpected input for Svdf op.
+ LOG(FATAL) << "Unexpected input node for SVDF op! Accepted inputs are: "
+ "weights_feature, weights_time and bias.";
+ }
+ }
+ const int rank = InferFilterRank();
+ CHECK_GT(rank, 0);
+
+ // Add Svdf activation and rank.
+ string activation_function =
+ StrContains(outputs_[0], "Relu") ? "Relu" : "None";
+ (*svdf_node->mutable_attr())["ActivationFunction"].set_s(activation_function);
+ (*svdf_node->mutable_attr())["Rank"].set_i(rank);
+
+ // Finally add it to the list of the newly created nodes.
+ new_nodes_.push_back(std::move(svdf_node));
+}
+
+void SvdfCluster::CreateConstNode(const string& const_pattern) {
+ // Find the nodes with pattern like: "const_pattern"/part_xxx of type Const.
+ std::vector<const NodeDef*> const_node_parts;
+ FilterPartitionedConstNodes(const_pattern, nodes_, &const_node_parts);
+
+ if (const_node_parts.empty()) return;
+
+ bool transpose_tensor_value =
+ StrContains(const_pattern, "SVDF_weights_feature");
+
+ // Merge them if necessary.
+ std::unique_ptr<tensorflow::NodeDef> merged_node(new NodeDef);
+ MaybeMergeConstNodes(const_node_parts, transpose_tensor_value, merged_node);
+ new_nodes_.push_back(std::move(merged_node));
+}
+
+void SvdfCluster::MaybeMergeConstNodes(
+ const std::vector<const NodeDef*>& const_node_parts,
+ bool transpose_tensor_value,
+ const std::unique_ptr<tensorflow::NodeDef>& merged_node) {
+ merged_node->set_name(const_node_parts[0]->name());
+ merged_node->set_op("Const");
+ merged_node->set_device(const_node_parts[0]->device());
+ (*merged_node->mutable_attr())["dtype"].set_type(
+ const_node_parts[0]->attr().at("dtype").type());
+
+ // Figuring out Value attribute for the merged node.
+ // Assuming the partitioning is done on Axis 0.
+ // The attributes which are inferred:
+ // * Shape and dimensions
+ // * Float content values
+
+ // Inferring shape and dimension
+ int dim0_size = 0;
+ int dim1_size = 1;
+ tensorflow::TensorProto* allocated_tensor =
+ (*merged_node->mutable_attr())["value"].mutable_tensor();
+ tensorflow::TensorShapeProto* allocated_tensor_shape =
+ allocated_tensor->mutable_tensor_shape();
+ auto tensor_shape_dim0 = allocated_tensor_shape->add_dim();
+ int allocated_content_flat_size = 0;
+ for (int i = 0; i < const_node_parts.size(); i++) {
+ const auto& value_attr = const_node_parts[i]->attr().at("value");
+ const tensorflow::TensorProto& tensor = value_attr.tensor();
+ if (i == 0) {
+ allocated_tensor->set_dtype(tensor.dtype());
+ } else {
+ CHECK_EQ(allocated_tensor->dtype(), tensor.dtype());
+ }
+ allocated_content_flat_size += tensor.tensor_content().size();
+ CHECK(tensor.has_tensor_shape());
+ const tensorflow::TensorShapeProto shape = tensor.tensor_shape();
+ dim0_size += shape.dim(0).size();
+ for (int d = 1; d < shape.dim_size(); d++) {
+ if (i == 0) {
+ allocated_tensor_shape->add_dim()->set_size(shape.dim(d).size());
+ allocated_tensor_shape->set_unknown_rank(shape.unknown_rank());
+ dim1_size *= shape.dim(d).size();
+ } else {
+ CHECK_EQ(shape.dim(d).size(), allocated_tensor_shape->dim(d).size());
+ CHECK_EQ(allocated_tensor_shape->unknown_rank(), shape.unknown_rank());
+ }
+ }
+ }
+
+ // Copying the float content from each array partition.
+ std::unique_ptr<char[]> allocated_content(
+ new char[allocated_content_flat_size]);
+ char* content_ptr = allocated_content.get();
+ for (int i = 0; i < const_node_parts.size(); i++) {
+ const auto& value_attr = const_node_parts[i]->attr().at("value");
+ const tensorflow::TensorProto& tensor = value_attr.tensor();
+ port::CopyToBuffer(tensor.tensor_content(), content_ptr);
+ content_ptr += tensor.tensor_content().size();
+ }
+
+ // Transpose the tensor if needed.
+ if (transpose_tensor_value) {
+ // We use dimension 0 to show the row size for the tensor.
+ // We use multiplication of the rest of dimension size to for the col size
+ // of the tensor.
+ std::unique_ptr<float[]> transposed_tensor(
+ new float[dim0_size * dim1_size]);
+ Transpose2DTensor(reinterpret_cast<float*>(allocated_content.get()),
+ dim0_size, dim1_size, transposed_tensor.get());
+ allocated_tensor_shape->clear_dim();
+ allocated_tensor_shape->add_dim()->set_size(dim1_size);
+ allocated_tensor_shape->add_dim()->set_size(dim0_size);
+
+ // Set the tensor attributes.
+ allocated_tensor->set_tensor_content(
+ string(reinterpret_cast<const char*>(transposed_tensor.get()),
+ allocated_content_flat_size));
+ } else {
+ tensor_shape_dim0->set_size(dim0_size);
+
+ // Set the tensor attributes.
+ allocated_tensor->set_tensor_content(
+ string(reinterpret_cast<const char*>(allocated_content.get()),
+ allocated_content_flat_size));
+ }
+}
+
+// SvdfClusterFactory methods
+
+std::unique_ptr<Cluster> SvdfClusterFactory::CreateCluster(
+ const NodeDef& node, const GraphDef& graph_def) const {
+ std::vector<string> node_patterns = {"SVDF_weights_feature",
+ "SVDF_weights_time", "SVDF_bias"};
+
+ string node_name_to_upper = node.name();
+ std::transform(node_name_to_upper.begin(), node_name_to_upper.end(),
+ node_name_to_upper.begin(), ::toupper);
+ std::unique_ptr<SvdfCluster> cluster = nullptr;
+ if (node_name_to_upper.find("SVDF", 0) != string::npos) {
+ size_t weights_pos = node.name().find(node_patterns[0]);
+ if (weights_pos != string::npos) {
+ // Assuming the node name has a pattern like:
+ // "SOMESTRING1/CELLNAME/SEARCH_PATTERN/SOMESTRING2", we use
+ // CELLNAME as the cluster name.
+ size_t cell_pos = node.name().rfind("/", weights_pos - 2) + 1;
+ string cell_name =
+ node.name().substr(cell_pos, weights_pos - cell_pos - 1);
+ cluster = std::unique_ptr<SvdfCluster>(new SvdfCluster);
+ cluster->SetName(cell_name);
+ cluster->SetDevice(node.device());
+ cluster->SetGraphDefInfo(&graph_def);
+ CHECK(cluster->FindClusterInputsAndOutputs());
+
+ for (const string& const_pattern : node_patterns) {
+ cluster->AddConstNodePattern(const_pattern);
+ }
+ }
+ }
+ return std::move(cluster);
+}
+
+} // end namespace toco
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h
new file mode 100644
index 0000000000..c4c6c34117
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h
@@ -0,0 +1,82 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_SVDF_H
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_SVDF_H
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+
+namespace toco {
+
+class SvdfCluster : public Cluster {
+ public:
+ // For this cluster, it collapses all the nodes in nodes_ into a composite op
+ // and it returns all the newly generated ops in new_nodes_.
+ void CreateNodes() override;
+
+ // A helper function to set the pattern of Const nodes which CreateNodes()
+ // should handle specially.
+ void AddConstNodePattern(const string& const_pattern) {
+ const_node_patterns_.push_back(const_pattern);
+ }
+
+ virtual ~SvdfCluster() {}
+
+ private:
+ // The main function which is used to create Const nodes for this cluster.
+ // These Const nodes are the inputs to the composite op generated for this
+ // cluster.
+ void CreateConstNode(const string& const_pattern);
+
+ // Receives a vector of Const nodes, merge them (if necessary) and returns
+ // only one Const node holding all the arrays contents. It transposes it if
+ // needed.
+ void MaybeMergeConstNodes(
+ const std::vector<const tensorflow::NodeDef*>& const_node_parts,
+ bool transpose_tensor_value,
+ const std::unique_ptr<tensorflow::NodeDef>& merged_node);
+
+ // Infer the value of Svdf filter rank, by looking up a reshape operator which
+ // is used for 'output' which reshapes output from [num_filters, batch, 1]
+ // shape to [num_units, rank, batch] shape. The 2nd shape element is rank.
+ int InferFilterRank();
+
+ std::vector<string> const_node_patterns_;
+};
+
+class SvdfClusterFactory : public ClusterFactoryInterface {
+ public:
+ // Creates a cluster of nodes using a name-based pattern matching approach. It
+ // uses a node as a seed and if its name matches a certain pattern, then it
+ // builds the cluster around that node.
+ // This factory expects nodes which have "SVDF_weights_feature" and
+ // "SVDF_weights_time" pattern in their names (and optionally "SVDF_bias")
+ // and it creates an SVDF Op from them.
+ std::unique_ptr<Cluster> CreateCluster(
+ const tensorflow::NodeDef& node,
+ const tensorflow::GraphDef& graph_def) const;
+};
+
+} // end namespace toco
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_SVDF_H
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc
new file mode 100644
index 0000000000..664e828c19
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc
@@ -0,0 +1,212 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h"
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/logging.h"
+
+using tensorflow::GraphDef;
+using tensorflow::NodeDef;
+
+namespace toco {
+
+class ResolveSvdfTest : public ::testing::Test {
+ public:
+ ResolveSvdfTest() {
+ AddNewNode("Input1", "Const", {});
+ AddNewNode("Svdf1/SVDF_weights_feature/part_0", "Const", {},
+ {0.1, 0.2, 0.3});
+ AddNewNode("Svdf1/SVDF_weights_feature/part_0/read", "Identity",
+ {"Svdf1/SVDF_weights_feature/part_0"});
+ AddNewNode("Svdf1/SVDF_weights_time/part_0", "Const", {}, {0.1, 0.2, 0.3});
+ AddNewNode("Svdf1/SVDF_weights_time/part_0/read", "Identity",
+ {"Svdf1/SVDF_weights_time/part_0"});
+
+ AddNewNode("Svdf1/f1", "SVDF_F1",
+ {"Input1", "Svdf1/SVDF_weights_feature/part_0/read"});
+ AddNewNode("Svdf1/f2", "SVDF_F2",
+ {"Svdf1/SVDF_weights_time/part_0/read", "Svdf1/f1"});
+ AddNewNode("Svdf1/Relu", "Relu", {"Svdf1/f2"});
+ AddShapeNode("Svdf1/Reshape/shape", {10, 1, -1});
+ AddNewNode("Output1", "Const", {"Svdf1/Relu"});
+
+ AddNewNode("Input2", "Const", {});
+ AddNewNode("Svdf2/SVDF_weights_feature/part_0", "Const", {},
+ {0.1, 0.2, 0.3});
+ AddNewNode("Svdf2/SVDF_weights_feature/part_0/read", "Identity",
+ {"Svdf2/SVDF_weights_feature/part_0"});
+ AddNewNode("Svdf2/SVDF_weights_time/part_0", "Const", {}, {0.1, 0.2, 0.3});
+ AddNewNode("Svdf2/SVDF_weights_time/part_0/read", "Identity",
+ {"Svdf2/SVDF_weights_time/part_0"});
+
+ AddNewNode("Svdf2/f1", "SVDF_F1",
+ {"Input1", "Svdf2/SVDF_weights_feature/part_0/read"});
+ AddNewNode("Svdf2/f2", "SVDF_F2",
+ {"Svdf2/SVDF_weights_time/part_0/read", "Svdf2/f1"});
+ AddNewNode("Svdf2/Relu", "Relu", {"Svdf2/f2"});
+ AddShapeNode("Svdf2/Reshape/shape", {10, 2, -1});
+ AddNewNode("Output2", "Const", {"Svdf2/Relu"});
+ }
+
+ ~ResolveSvdfTest() override {}
+
+ protected:
+ void AddNewNode(const string& name, const string& op,
+ const std::vector<string>& inputs) {
+ NodeDef* node = graph_.add_node();
+ node->set_name(name);
+ node->set_op(op);
+ node->set_device("");
+ for (int i = 0; i < inputs.size(); i++) {
+ node->add_input();
+ node->set_input(i, inputs[i]);
+ }
+ }
+
+ void AddNewNode(const string& name, const string& op,
+ const std::vector<string>& inputs,
+ const std::vector<float>& values) {
+ NodeDef* node = graph_.add_node();
+ node->set_name(name);
+ node->set_op(op);
+ node->set_device("");
+ for (int i = 0; i < inputs.size(); i++) {
+ node->add_input();
+ node->set_input(i, inputs[i]);
+ }
+ // Add the float vector as an attribute to the node.
+ (*node->mutable_attr())["dtype"].set_type(tensorflow::DT_FLOAT);
+ tensorflow::TensorProto* allocated_tensor = new tensorflow::TensorProto;
+ tensorflow::TensorShapeProto* allocated_tesnor_shape =
+ new tensorflow::TensorShapeProto;
+ auto tensor_shape_dim0 = allocated_tesnor_shape->add_dim();
+ tensor_shape_dim0->set_size(values.size());
+ allocated_tensor->set_allocated_tensor_shape(allocated_tesnor_shape);
+ allocated_tensor->set_tensor_content(
+ string(reinterpret_cast<const char*>(values.data()),
+ values.size() * sizeof(float)));
+ (*node->mutable_attr())["value"].set_allocated_tensor(allocated_tensor);
+ }
+
+ void AddShapeNode(const string& name, const std::vector<int>& values) {
+ NodeDef* node = graph_.add_node();
+ node->set_name(name);
+ node->set_op("Const");
+ node->set_device("");
+ // Add the float vector as an attribute to the node.
+ (*node->mutable_attr())["dtype"].set_type(tensorflow::DT_INT32);
+ tensorflow::TensorProto* allocated_tensor = new tensorflow::TensorProto;
+ tensorflow::TensorShapeProto* allocated_tesnor_shape =
+ new tensorflow::TensorShapeProto;
+ auto tensor_shape_dim0 = allocated_tesnor_shape->add_dim();
+ tensor_shape_dim0->set_size(values.size());
+ allocated_tensor->set_allocated_tensor_shape(allocated_tesnor_shape);
+ allocated_tensor->set_tensor_content(
+ string(reinterpret_cast<const char*>(values.data()),
+ values.size() * sizeof(int)));
+ (*node->mutable_attr())["value"].set_allocated_tensor(allocated_tensor);
+ }
+
+ GraphDef graph_;
+ SvdfClusterFactory svdf_cluster_factory_;
+ std::vector<std::unique_ptr<Cluster>> clusters_;
+};
+
+TEST_F(ResolveSvdfTest, TestTranspose2DTensor) {
+ static float matrix[] = {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.};
+ static float expected_transposed_matrix[] = {1., 5., 9., 2., 6., 10.,
+ 3., 7., 11., 4., 8., 12.};
+ float* transposed_matrix = new float[12];
+ Transpose2DTensor(matrix, 3, 4, transposed_matrix);
+
+ std::vector<float> actual;
+ actual.insert(
+ actual.end(), transposed_matrix,
+ transposed_matrix + sizeof(expected_transposed_matrix) / sizeof(float));
+ std::vector<float> expected;
+ expected.insert(expected.end(), expected_transposed_matrix,
+ expected_transposed_matrix +
+ sizeof(expected_transposed_matrix) / sizeof(float));
+ delete[] transposed_matrix;
+}
+
+TEST_F(ResolveSvdfTest, TestResolveSvdfFlow) {
+ std::unordered_map<string, bool> is_node_in_cluster;
+ for (const NodeDef& node : graph_.node()) {
+ is_node_in_cluster[node.name()] = false;
+ }
+
+ std::vector<string> cluster_names;
+ CHECK(FindCluster(svdf_cluster_factory_, graph_, &is_node_in_cluster,
+ &clusters_));
+
+ for (const std::unique_ptr<Cluster>& cluster : clusters_) {
+ cluster_names.push_back(cluster->GetName());
+ cluster->CreateNodes();
+ }
+
+ EXPECT_THAT(cluster_names,
+ testing::UnorderedElementsAreArray({"Svdf1", "Svdf2"}));
+
+ std::vector<string> new_node_names;
+ std::vector<float> content_array(3);
+ for (const std::unique_ptr<Cluster>& cluster : clusters_) {
+ // After CreateNodes in each cluster we have three nodes: Svdf,
+ // weights_feature and weights_time.
+ CHECK_EQ(cluster->GetNewNodes().size(), 3);
+ for (const std::unique_ptr<tensorflow::NodeDef>& node :
+ cluster->GetNewNodes()) {
+ new_node_names.push_back(node->name());
+ if (node->op() == "Const") {
+ CHECK_EQ(node->attr().at("dtype").type(), tensorflow::DT_FLOAT);
+ toco::port::CopyToBuffer(
+ node->attr().at("value").tensor().tensor_content(),
+ reinterpret_cast<char*>(content_array.data()));
+ EXPECT_THAT(content_array,
+ testing::UnorderedElementsAreArray({0.1, 0.2, 0.3}));
+ } else {
+ // Checking the Svdf node attributes (rank and activation type) are
+ // correct.
+ if (node->name() == "Svdf1") {
+ CHECK_EQ(node->attr().at("Rank").i(), 1);
+ } else if (node->name() == "Svdf2") {
+ CHECK_EQ(node->attr().at("Rank").i(), 2);
+ }
+ CHECK_EQ(node->attr().at("ActivationFunction").s(), "Relu");
+ }
+ }
+ }
+ EXPECT_THAT(new_node_names, testing::UnorderedElementsAreArray(
+ {"Svdf2/SVDF_weights_feature/part_0",
+ "Svdf2/SVDF_weights_time/part_0", "Svdf2",
+ "Svdf1/SVDF_weights_feature/part_0",
+ "Svdf1/SVDF_weights_time/part_0", "Svdf1"}));
+}
+
+} // end namespace toco
diff --git a/tensorflow/contrib/lite/toco/tensorflow_util.cc b/tensorflow/contrib/lite/toco/tensorflow_util.cc
new file mode 100644
index 0000000000..82e2800ca2
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tensorflow_util.cc
@@ -0,0 +1,197 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tensorflow_util.h"
+
+#include <string.h>
+#include <memory>
+#include <set>
+
+#ifdef GOOGLE_PLATFORM
+#include "file/logging/log_lines.h"
+#endif
+#include "google/protobuf/map.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+using tensorflow::AttrValue;
+using tensorflow::GraphDef;
+
+void LogDumpGraphDef(int log_level, const string& message,
+ const GraphDef& tf_graph) {
+ if (!VLOG_IS_ON(log_level)) {
+ return;
+ }
+ std::set<string> ops;
+ for (const auto& node : tf_graph.node()) {
+ ops.insert(node.op());
+ }
+ string dump;
+ toco::port::AppendF(&dump, R"MSG(
+BEGIN DUMP OF TENSORFLOW GRAPHDEF (%s)
+There are %d nodes.
+There are %zu different op types:
+)MSG", message, tf_graph.node_size(), ops.size());
+ for (const auto& op : ops) {
+ toco::port::AppendF(&dump, " %s\n", op);
+ }
+ dump.append(R"MSG(
+PROTO DUMP
+)MSG");
+ for (const auto& node : tf_graph.node()) {
+ toco::port::AppendF(&dump, R"MSG(
+BEGIN NODE: name = %s
+ op = %s
+ inputs = [
+)MSG", node.name(), node.op());
+ for (const auto& input : node.input()) {
+ toco::port::AppendF(&dump, " %s\n", input);
+ }
+ dump.append(" ]\n");
+ for (const auto& attr : node.attr()) {
+ toco::port::AppendF(&dump, " ATTR: name = %s\n", attr.first);
+ if (attr.second.value_case() == AttrValue::kFunc) {
+ dump.append(" func\n");
+ } else if (attr.second.value_case() == AttrValue::kPlaceholder) {
+ toco::port::AppendF(&dump, " placeholder: %s\n",
+ attr.second.placeholder());
+ } else if (attr.second.value_case() == AttrValue::kS) {
+ dump.append(" string:\n");
+ dump.append(R"MSG(
+ BEGIN EMBEDDED STRING
+)MSG");
+ const auto& lines = absl::StrSplit(attr.second.s(), '\n');
+ for (const auto& line : lines) {
+ toco::port::AppendF(&dump, " %s\n", line);
+ }
+ dump.append(R"MSG(
+ END EMBEDDED STRING
+)MSG");
+ } else if (attr.second.value_case() == AttrValue::kI) {
+ toco::port::AppendF(&dump, " int: %lld\n", attr.second.i());
+ } else if (attr.second.value_case() == AttrValue::kF) {
+ toco::port::AppendF(&dump, " float: %g\n", attr.second.f());
+ } else if (attr.second.value_case() == AttrValue::kB) {
+ toco::port::AppendF(&dump, " bool: %s\n",
+ attr.second.b() ? "true" : "false");
+ } else if (attr.second.value_case() == AttrValue::kType) {
+ toco::port::AppendF(&dump, " type: %s\n",
+ tensorflow::DataType_Name(attr.second.type()));
+ } else if (attr.second.value_case() == AttrValue::kShape) {
+ dump.append(" shape: [ ");
+ const auto& shape = attr.second.shape();
+ for (int i = 0; i < shape.dim_size(); i++) {
+ toco::port::AppendF(&dump, "%lld ", shape.dim(i).size());
+ }
+ dump.append("]\n");
+ } else if (attr.second.value_case() == AttrValue::kTensor) {
+ const auto& tensor = attr.second.tensor();
+ dump.append(" TENSOR:\n");
+ toco::port::AppendF(&dump, " type: %s\n",
+ tensorflow::DataType_Name(tensor.dtype()));
+ const auto& shape = tensor.tensor_shape();
+ dump.append(" shape: [ ");
+ for (int i = 0; i < shape.dim_size(); i++) {
+ toco::port::AppendF(&dump, "%lld ", shape.dim(i).size());
+ }
+ dump.append("]\n");
+ if (!tensor.tensor_content().empty()) {
+ toco::port::AppendF(&dump, " tensor_content: %zu bytes\n",
+ tensor.tensor_content().size());
+ }
+ if (tensor.dtype() == tensorflow::DT_INT32) {
+ CHECK_EQ(0, tensor.tensor_content().size() % sizeof(int32));
+ const int size = tensor.tensor_content().size() / sizeof(int32);
+ std::vector<int32> data(size);
+ toco::port::CopyToBuffer(tensor.tensor_content(),
+ reinterpret_cast<char*>(data.data()));
+ const int kMaxValsToPrint = 4;
+ dump.append(" tensor_content as ints: [ ");
+ for (int i = 0; i < kMaxValsToPrint && i < size; i++) {
+ toco::port::AppendF(&dump, "%d ", data[i]);
+ }
+ if (size > kMaxValsToPrint) {
+ dump.append("... ");
+ }
+ dump.append("]\n");
+ }
+ if (tensor.dtype() == tensorflow::DT_FLOAT) {
+ CHECK_EQ(0, tensor.tensor_content().size() % sizeof(float));
+ const int size = tensor.tensor_content().size() / sizeof(float);
+ std::vector<float> data(size);
+ toco::port::CopyToBuffer(tensor.tensor_content(),
+ reinterpret_cast<char*>(data.data()));
+ const int kMaxValsToPrint = 4;
+ dump.append(" tensor_content as floats: [ ");
+ for (int i = 0; i < kMaxValsToPrint && i < size; i++) {
+ toco::port::AppendF(&dump, "%g ", data[i]);
+ }
+ if (size > kMaxValsToPrint) {
+ dump.append("... ");
+ }
+ dump.append("]\n");
+ }
+ if (tensor.int_val_size()) {
+ toco::port::AppendF(&dump, " int_val: %d ints: [ ",
+ tensor.int_val_size());
+ const int kMaxValsToPrint = 4;
+ for (int i = 0; i < kMaxValsToPrint && i < tensor.int_val_size();
+ i++) {
+ toco::port::AppendF(&dump, "%d ", tensor.int_val(i));
+ }
+ if (tensor.int_val_size() > kMaxValsToPrint) {
+ dump.append("... ");
+ }
+ dump.append("]\n");
+ }
+ if (tensor.float_val_size()) {
+ toco::port::AppendF(&dump, " float_val: %d floats: [ ",
+ tensor.float_val_size());
+ const int kMaxValsToPrint = 4;
+ for (int i = 0; i < kMaxValsToPrint && i < tensor.float_val_size();
+ i++) {
+ toco::port::AppendF(&dump, "%g ", tensor.float_val(i));
+ }
+ if (tensor.float_val_size() > kMaxValsToPrint) {
+ dump.append("... ");
+ }
+ dump.append("]\n");
+ }
+ if (tensor.string_val_size()) {
+ toco::port::AppendF(&dump, " string_val: %d strings\n",
+ tensor.string_val_size());
+ }
+ } else if (attr.second.value_case() == AttrValue::kList) {
+ dump.append(" LIST\n");
+ }
+ }
+ dump.append("END NODE\n");
+ }
+ toco::port::AppendF(&dump, "END DUMP OF TENSORFLOW GRAPHDEF (%s)\n", message);
+#if defined(GOOGLE_PLATFORM)
+ VLOG_LINES(log_level, dump);
+#else
+ VLOG(log_level) << dump;
+#endif
+}
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tensorflow_util.h b/tensorflow/contrib/lite/toco/tensorflow_util.h
new file mode 100644
index 0000000000..152b4f7a72
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tensorflow_util.h
@@ -0,0 +1,32 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+
+namespace toco {
+
+void LogDumpGraphDef(int log_level, const string& message,
+ const tensorflow::GraphDef& tf_graph);
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_
diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD
new file mode 100644
index 0000000000..e910e3957f
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/BUILD
@@ -0,0 +1,142 @@
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+)
+
+cc_library(
+ name = "operator",
+ srcs = [
+ "operator.cc",
+ ],
+ hdrs = [
+ "builtin_operator.h",
+ "custom_operator.h",
+ "operator.h",
+ "simple_operator.h",
+ ],
+ deps = [
+ ":types",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ "//tensorflow/contrib/lite/toco:model",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/memory",
+ "@flatbuffers//:flatbuffers",
+ ],
+)
+
+tf_cc_test(
+ name = "operator_test",
+ srcs = [
+ "operator_test.cc",
+ ],
+ deps = [
+ ":operator",
+ "//tensorflow/contrib/lite/toco:tooling_util",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_googletest//:gtest_main",
+ "@flatbuffers//:flatbuffers",
+ ],
+)
+
+cc_library(
+ name = "types",
+ srcs = [
+ "types.cc",
+ ],
+ hdrs = [
+ "types.h",
+ ],
+ deps = [
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ "//tensorflow/contrib/lite/toco:model",
+ ],
+)
+
+tf_cc_test(
+ name = "types_test",
+ srcs = [
+ "types_test.cc",
+ ],
+ deps = [
+ ":types",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "export",
+ srcs = [
+ "export.cc",
+ ],
+ hdrs = [
+ "export.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":operator",
+ ":types",
+ "//tensorflow/contrib/lite:schema_fbs_version",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ "//tensorflow/contrib/lite/toco:model",
+ "//tensorflow/contrib/lite/toco:tooling_util",
+ "@com_google_absl//absl/strings",
+ "@flatbuffers//:flatbuffers",
+ ],
+)
+
+tf_cc_test(
+ name = "export_test",
+ srcs = [
+ "export_test.cc",
+ ],
+ deps = [
+ ":export",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "import",
+ srcs = [
+ "import.cc",
+ ],
+ hdrs = [
+ "import.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":operator",
+ ":types",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ "//tensorflow/contrib/lite/toco:model",
+ "@flatbuffers//:flatbuffers",
+ ],
+)
+
+tf_cc_test(
+ name = "import_test",
+ srcs = [
+ "import_test.cc",
+ ],
+ deps = [
+ ":import",
+ "//tensorflow/contrib/lite:schema_fbs_version",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ "@com_google_googletest//:gtest_main",
+ "@flatbuffers//:flatbuffers",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/toco/tflite/builtin_operator.h b/tensorflow/contrib/lite/toco/tflite/builtin_operator.h
new file mode 100644
index 0000000000..93cc79ddb6
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/builtin_operator.h
@@ -0,0 +1,74 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_
+
+#include "absl/memory/memory.h"
+#include "tensorflow/contrib/lite/toco/tflite/operator.h"
+
+namespace toco {
+
+namespace tflite {
+
+// Builtin operators have special TF Lite objects describing their options.
+// This class has the boilerplate code for creating those.
+//
+// Template arguments:
+// - T1 must derive from ::toco::Operator.
+// - T2 must be one of TF Lite's objects defining Builtin Options, such as
+// ::tflite::Conv2DOptions.
+template <typename T1, typename T2, ::tflite::BuiltinOptions TfLiteEnum>
+class BuiltinOperator : public BaseOperator {
+ public:
+ using TocoOperator = T1;
+ using TfLiteOptions = T2;
+
+ BuiltinOperator(::tflite::BuiltinOperator op, OperatorType type)
+ : BaseOperator(::tflite::EnumNameBuiltinOperator(op), type) {}
+
+ // Build the configuration object in the given flatbuffer builder. Return
+ // its offset.
+ virtual flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const = 0;
+
+ // Read options from the TF Lite object and set the corresponding values in
+ // the tf.mini operator.
+ virtual void ReadOptions(const TfLiteOptions& opt,
+ TocoOperator* op) const = 0;
+
+ Options Serialize(const Operator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto options = WriteOptions(static_cast<const TocoOperator&>(op), builder);
+ return Options::Builtin(TfLiteEnum, options.Union());
+ }
+
+ std::unique_ptr<Operator> Deserialize(
+ const BuiltinOptions* builtin_options,
+ const CustomOptions* custom_options) const override {
+ auto op = absl::make_unique<TocoOperator>();
+ auto* options = static_cast<const TfLiteOptions*>(builtin_options);
+ if (options) {
+ ReadOptions(*options, op.get());
+ }
+ return std::unique_ptr<Operator>(op.release());
+ }
+};
+
+} // namespace tflite
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_
diff --git a/tensorflow/contrib/lite/toco/tflite/custom_operator.h b/tensorflow/contrib/lite/toco/tflite/custom_operator.h
new file mode 100644
index 0000000000..1a4bfac7d4
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/custom_operator.h
@@ -0,0 +1,74 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_
+
+#include "flatbuffers/flexbuffers.h"
+#include "absl/memory/memory.h"
+#include "tensorflow/contrib/lite/toco/tflite/operator.h"
+
+namespace toco {
+
+namespace tflite {
+
+// Custom operators have a generic byte buffer describing their options. This
+// class provides the boilerplate code for populating those options using
+// flexbuffers. Note that most of toco's operators will likely be supported
+// as builtin operators in TF Lite.
+//
+// Template argument T must derive from ::toco::Operator.
+template <typename T>
+class CustomOperator : public BaseOperator {
+ public:
+ using TocoOperator = T;
+ using BaseOperator::BaseOperator;
+
+ // Populate the given flexbuffer with options obtained from the tf.mini
+ // operator.
+ virtual void WriteOptions(const TocoOperator& op,
+ flexbuffers::Builder* fbb) const {}
+
+ // Set options in the given tf.mini operator using values from the flexbuffer
+ // map.
+ virtual void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const {}
+
+ Options Serialize(const Operator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ flexbuffers::Builder fbb;
+ fbb.Map(
+ [&]() { WriteOptions(static_cast<const TocoOperator&>(op), &fbb); });
+ fbb.Finish();
+ return Options::Custom(builder->CreateVector(fbb.GetBuffer()));
+ }
+
+ std::unique_ptr<Operator> Deserialize(
+ const BuiltinOptions* builtin_options,
+ const CustomOptions* custom_options) const override {
+ auto op = absl::make_unique<TocoOperator>();
+ if (custom_options) {
+ auto flexbuffer_map =
+ flexbuffers::GetRoot(custom_options->data(), custom_options->size())
+ .AsMap();
+ ReadOptions(flexbuffer_map, op.get());
+ }
+ return std::unique_ptr<Operator>(op.release());
+ }
+};
+
+} // namespace tflite
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
new file mode 100644
index 0000000000..beda710614
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -0,0 +1,322 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tflite/export.h"
+
+#include "flatbuffers/flexbuffers.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/toco/tflite/operator.h"
+#include "tensorflow/contrib/lite/toco/tflite/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/contrib/lite/version.h"
+
+namespace toco {
+
+namespace tflite {
+
+using ::tflite::Buffer;
+using ::tflite::BuiltinOperator;
+using ::tflite::BuiltinOperator_CUSTOM;
+using ::tflite::BuiltinOperator_MAX;
+using ::tflite::BuiltinOperator_MIN;
+using ::tflite::CreateBuffer;
+using ::tflite::CreateModel;
+using ::tflite::CreateOperator;
+using ::tflite::CreateTensor;
+using ::tflite::Operator;
+using ::tflite::OperatorCode;
+using ::tflite::SubGraph;
+using ::tflite::Tensor;
+using flatbuffers::FlatBufferBuilder;
+using flatbuffers::Offset;
+using flatbuffers::Vector;
+
+namespace {
+
+details::OperatorKey GetOperatorKey(const ::toco::Operator& op) {
+ string custom_code;
+ if (op.type == OperatorType::kTensorFlowUnsupported) {
+ const TensorFlowUnsupportedOperator& unsupported_op =
+ static_cast<const TensorFlowUnsupportedOperator&>(op);
+ custom_code = unsupported_op.tensorflow_op;
+ }
+ return details::OperatorKey(op.type, custom_code);
+}
+
+} // Anonymous namespace.
+
+namespace details {
+
+void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) {
+ // First find a list of unique array names.
+ std::set<string> names;
+ for (const auto& array_pair : model.arrays) {
+ names.insert(array_pair.first);
+ }
+
+ // Now assign indices to them and fill in the map.
+ int index = 0;
+ for (const auto& name : names) {
+ (*tensors_map)[name] = index;
+ ++index;
+ }
+}
+
+void LoadOperatorsMap(const Model& model, OperatorsMap* operators_map) {
+ // First find a list of unique operator types.
+ std::set<OperatorKey> keys;
+ for (const auto& op : model.operators) {
+ keys.insert(GetOperatorKey(*op));
+ }
+ // Now assign indices to them and fill in the map.
+ int index = 0;
+ for (const auto& key : keys) {
+ (*operators_map)[key] = index;
+ ++index;
+ }
+}
+} // namespace details
+
+Offset<Vector<Offset<Tensor>>> ExportTensors(
+ const Model& model, const details::TensorsMap& tensors_map,
+ FlatBufferBuilder* builder, std::vector<const Array*>* buffers_to_write) {
+ // In the end we will need to produce a vector sorted by the indices of the
+ // tensors in the tensors_map.
+ std::map<int, Offset<Tensor>> ordered_tensors;
+
+ for (const auto& array_pair : model.arrays) {
+ const string& tensor_name = array_pair.first;
+ const toco::Array& array = *array_pair.second;
+
+ int buffer_index = buffers_to_write->size();
+ auto type = DataType::Serialize(array.data_type);
+ buffers_to_write->push_back(&array);
+
+ std::vector<int> shape;
+ if (array.has_shape()) {
+ for (int d : array.shape().dims()) {
+ shape.push_back(d);
+ }
+ }
+
+ Offset<Vector<float>> min;
+ Offset<Vector<float>> max;
+ Offset<Vector<float>> scale;
+ Offset<Vector<int64_t>> zero_point;
+ if (array.minmax) {
+ min = builder->CreateVector(
+ std::vector<float>{static_cast<float>(array.minmax->min)});
+ max = builder->CreateVector(
+ std::vector<float>{static_cast<float>(array.minmax->max)});
+ }
+ if (array.quantization_params) {
+ scale = builder->CreateVector(std::vector<float>{
+ static_cast<float>(array.quantization_params->scale)});
+ zero_point = builder->CreateVector(
+ std::vector<int64_t>{array.quantization_params->zero_point});
+ }
+ auto q_param = ::tflite::CreateQuantizationParameters(*builder, min, max,
+ scale, zero_point);
+
+ int index = tensors_map.at(tensor_name);
+ ordered_tensors[index] =
+ CreateTensor(*builder, builder->CreateVector(shape), type, buffer_index,
+ builder->CreateString(tensor_name), q_param);
+ }
+
+ std::vector<Offset<Tensor>> tensor_vector;
+ tensor_vector.reserve(ordered_tensors.size());
+ for (const auto& tensor : ordered_tensors) {
+ tensor_vector.push_back(tensor.second);
+ }
+
+ return builder->CreateVector(tensor_vector);
+}
+
+Offset<Vector<int32_t>> ExportInputTensors(
+ const Model& model, const details::TensorsMap& tensors_map,
+ FlatBufferBuilder* builder) {
+ std::vector<int32_t> inputs;
+ for (const auto& input : model.flags.input_arrays()) {
+ inputs.push_back(tensors_map.at(input.name()));
+ }
+ return builder->CreateVector<int32_t>(inputs);
+}
+
+Offset<Vector<int32_t>> ExportOutputTensors(
+ const Model& model, const details::TensorsMap& tensors_map,
+ FlatBufferBuilder* builder) {
+ std::vector<int32_t> outputs;
+ for (const string& output : model.flags.output_arrays()) {
+ outputs.push_back(tensors_map.at(output));
+ }
+ return builder->CreateVector<int32_t>(outputs);
+}
+
+Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
+ const Model& model,
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+ const details::OperatorsMap& operators_map, FlatBufferBuilder* builder,
+ std::set<string>* error_summary) {
+ // Map from operator name to TF Lite enum value, for all builtins.
+ std::map<string, BuiltinOperator> builtin_ops;
+ for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) {
+ BuiltinOperator op = static_cast<BuiltinOperator>(i);
+ string name = EnumNameBuiltinOperator(op);
+ if (op != BuiltinOperator_CUSTOM && !name.empty()) {
+ builtin_ops[name] = op;
+ }
+ }
+
+ // We will need to produce a vector of codes in the same order as they
+ // appear in the operators_map.
+ std::map<int, Offset<OperatorCode>> ordered_opcodes;
+
+ for (const auto& op : model.operators) {
+ const details::OperatorKey operator_key = GetOperatorKey(*op);
+ int op_index = operators_map.at(operator_key);
+
+ if (ops_by_type.count(op->type) == 0) {
+ LOG(FATAL) << "Unsupported operator: " << HelpfulOperatorTypeName(*op);
+ }
+
+ string name = ops_by_type.at(op->type)->name();
+ if (builtin_ops.count(name) > 0) {
+ ordered_opcodes[op_index] =
+ CreateOperatorCode(*builder, builtin_ops[name], 0);
+ } else {
+ // If use the custom operation code if it's available in the OperatorKey.
+ if (!operator_key.custom_code.empty()) {
+ name = operator_key.custom_code;
+ }
+ if (error_summary) {
+ error_summary->insert(name);
+ }
+ ordered_opcodes[op_index] = CreateOperatorCode(
+ *builder, BuiltinOperator_CUSTOM, builder->CreateString(name));
+ }
+ }
+
+ std::vector<Offset<OperatorCode>> opcode_vector;
+ opcode_vector.reserve(ordered_opcodes.size());
+ for (const auto& opcode : ordered_opcodes) {
+ opcode_vector.push_back(opcode.second);
+ }
+
+ return builder->CreateVector(opcode_vector);
+}
+
+Offset<Vector<Offset<Operator>>> ExportOperators(
+ const Model& model,
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+ const details::OperatorsMap& operators_map,
+ const details::TensorsMap& tensors_map, FlatBufferBuilder* builder) {
+ // The operators are in execution order, so we just follow tf.mini order.
+ std::vector<Offset<Operator>> op_vector;
+ for (const auto& op : model.operators) {
+ if (ops_by_type.count(op->type) == 0) {
+ LOG(FATAL) << "Op type '" << OperatorTypeName(op->type)
+ << "' not supported";
+ }
+
+ std::vector<int32_t> inputs;
+ for (const string& input : op->inputs) {
+ inputs.push_back(tensors_map.at(input));
+ }
+
+ std::vector<int32_t> outputs;
+ for (const string& output : op->outputs) {
+ outputs.push_back(tensors_map.at(output));
+ }
+
+ auto options = ops_by_type.at(op->type)->Serialize(*op, builder);
+ int op_index = operators_map.at(GetOperatorKey(*op));
+ // The only supported CustomOptionFormat is FLEXBUFFERS now.
+ op_vector.push_back(CreateOperator(
+ *builder, op_index, builder->CreateVector(inputs),
+ builder->CreateVector(outputs), options.type, options.builtin,
+ options.custom, ::tflite::CustomOptionsFormat_FLEXBUFFERS));
+ }
+
+ return builder->CreateVector(op_vector);
+}
+
+Offset<Vector<Offset<Buffer>>> ExportBuffers(
+ const Model& model, const std::vector<const Array*>& buffers_to_write,
+ FlatBufferBuilder* builder) {
+ std::vector<Offset<Buffer>> buffer_vector;
+ size_t index = 0;
+ for (const Array* array_ptr : buffers_to_write) {
+ const Array& array = *array_ptr;
+ Offset<Vector<uint8_t>> data_buffer = DataBuffer::Serialize(array, builder);
+ buffer_vector.push_back(CreateBuffer(*builder, data_buffer));
+ index++;
+ }
+ return builder->CreateVector(buffer_vector);
+}
+
+void Export(const Model& model, bool allow_custom_ops,
+ string* output_file_contents) {
+ flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
+
+ const auto ops_by_type = BuildOperatorByTypeMap();
+
+ details::TensorsMap tensors_map;
+ details::LoadTensorsMap(model, &tensors_map);
+
+ details::OperatorsMap operators_map;
+ details::LoadOperatorsMap(model, &operators_map);
+
+ std::vector<const Array*> buffers_to_write;
+ Array empty_array;
+ buffers_to_write.push_back(&empty_array);
+
+ auto tensors = ExportTensors(model, tensors_map, &builder, &buffers_to_write);
+ auto inputs = ExportInputTensors(model, tensors_map, &builder);
+ auto outputs = ExportOutputTensors(model, tensors_map, &builder);
+
+ std::set<string> error_summary;
+ auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map,
+ &builder, &error_summary);
+ if (!allow_custom_ops && !error_summary.empty()) {
+ LOG(QFATAL) << "Some of the operators in the model are not supported by "
+ "the standard TensorFlow Lite runtime. If you have a custom "
+ "implementation for them you can disable this error with "
+ "--allow_custom_ops. Here is a list of operators for which "
+ "you will need custom implementations: "
+ << absl::StrJoin(error_summary, ", ") << ".";
+ }
+
+ auto ops =
+ ExportOperators(model, ops_by_type, operators_map, tensors_map, &builder);
+
+ // TODO(aselle): add support to toco for multiple subgraphs.
+ auto subgraph = CreateSubGraph(builder, tensors, inputs, outputs, ops);
+ std::vector<flatbuffers::Offset<SubGraph>> subgraphs = {subgraph};
+
+ auto buffers = ExportBuffers(model, buffers_to_write, &builder);
+ auto description = builder.CreateString("TOCO Converted.");
+ auto new_model_location =
+ CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
+ builder.CreateVector(subgraphs), description, buffers);
+ ::tflite::FinishModelBuffer(builder, new_model_location);
+ const uint8_t* buffer = builder.GetBufferPointer();
+ int size = builder.GetSize();
+ *output_file_contents = string(reinterpret_cast<const char*>(buffer), size);
+}
+
+} // namespace tflite
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h
new file mode 100644
index 0000000000..44012b7126
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/export.h
@@ -0,0 +1,76 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_
+
+#include "tensorflow/contrib/lite/toco/model.h"
+
+namespace toco {
+
+namespace tflite {
+
+// Transform the given tf.mini model into a TF Lite flatbuffer and deposit the
+// result in the given string.
+void Export(const Model& model, bool allow_custom_ops,
+ string* output_file_contents);
+// This if backward-compatibility.
+inline void Export(const Model& model, string* output_file_contents) {
+ Export(model, true, output_file_contents);
+}
+
+namespace details {
+
+// A maps from tensor name to its final position in the TF Lite buffer.
+using TensorsMap = std::unordered_map<string, int>;
+
+// A key to identify an operator.
+// Only when `type` is `kTensorFlowUnsupported`, `custom_code` is filled to
+// identify which operation is used.
+struct OperatorKey {
+ OperatorKey(OperatorType type, const std::string& custom_code)
+ : type(type), custom_code(custom_code) {}
+ const OperatorType type;
+ const std::string custom_code;
+
+ bool operator<(const OperatorKey& other) const {
+ if (type < other.type) return true;
+ if (type > other.type) return false;
+ return custom_code < other.custom_code;
+ }
+
+ bool operator==(const OperatorKey& other) const {
+ return type == other.type && custom_code == other.custom_code;
+ }
+
+ struct Hash {
+ std::size_t operator()(const OperatorKey& key) const {
+ return std::hash<size_t>()(static_cast<size_t>(key.type)) ^
+ std::hash<std::string>()(key.custom_code);
+ }
+ };
+};
+
+// A maps from operator type to its final position in the TF Lite buffer.
+using OperatorsMap = std::unordered_map<OperatorKey, int, OperatorKey::Hash>;
+
+void LoadTensorsMap(const Model& model, TensorsMap* tensors_map);
+void LoadOperatorsMap(const Model& model, OperatorsMap* operators_map);
+
+} // namespace details
+} // namespace tflite
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_
diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc
new file mode 100644
index 0000000000..e395645383
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc
@@ -0,0 +1,69 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tflite/export.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace toco {
+
+namespace tflite {
+namespace {
+
+class ExportTest : public ::testing::Test {
+ protected:
+ // This is a very simplistic model. We are not interested in testing all the
+ // details here, since tf.mini's testing framework will be exercising all the
+ // conversions multiple times, and the conversion of operators is tested by
+ // separate unittests.
+ void BuildTestModel() {
+ input_model_.GetOrCreateArray("tensor_one");
+ input_model_.GetOrCreateArray("tensor_two");
+ input_model_.operators.emplace_back(new ConvOperator);
+ input_model_.operators.emplace_back(new AddOperator);
+ auto unsupported_operator = new TensorFlowUnsupportedOperator;
+ unsupported_operator->tensorflow_op = "MyCrazyOp";
+ input_model_.operators.emplace_back(unsupported_operator);
+ }
+
+ Model input_model_;
+};
+
+TEST_F(ExportTest, LoadTensorsMap) {
+ BuildTestModel();
+
+ details::TensorsMap tensors;
+ details::LoadTensorsMap(input_model_, &tensors);
+ EXPECT_EQ(0, tensors["tensor_one"]);
+ EXPECT_EQ(1, tensors["tensor_two"]);
+}
+
+TEST_F(ExportTest, LoadOperatorsMap) {
+ BuildTestModel();
+
+ details::OperatorsMap operators;
+ details::LoadOperatorsMap(input_model_, &operators);
+ EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "")]);
+ EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "")]);
+ EXPECT_EQ(2, operators[details::OperatorKey(
+ OperatorType::kTensorFlowUnsupported, "MyCrazyOp")]);
+}
+
+// TODO(ahentz): tests for tensors, inputs, outpus, opcodes and operators.
+
+} // namespace
+} // namespace tflite
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tflite/import.cc b/tensorflow/contrib/lite/toco/tflite/import.cc
new file mode 100644
index 0000000000..bbf201fd28
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/import.cc
@@ -0,0 +1,183 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tflite/import.h"
+
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/toco/tflite/operator.h"
+#include "tensorflow/contrib/lite/toco/tflite/types.h"
+
+namespace toco {
+
+namespace tflite {
+
+namespace details {
+void LoadTensorsTable(const ::tflite::Model& input_model,
+ TensorsTable* tensors_table) {
+ // TODO(aselle): add support to toco for multiple subgraphs.
+ auto tensors = (*input_model.subgraphs())[0]->tensors();
+ if (!tensors) return;
+ for (const auto* tensor : *tensors) {
+ tensors_table->push_back(tensor->name()->c_str());
+ }
+}
+
+void LoadOperatorsTable(const ::tflite::Model& input_model,
+ OperatorsTable* operators_table) {
+ auto opcodes = input_model.operator_codes();
+ if (!opcodes) return;
+ for (const auto* opcode : *opcodes) {
+ if (opcode->builtin_code() != ::tflite::BuiltinOperator_CUSTOM) {
+ operators_table->push_back(
+ EnumNameBuiltinOperator(opcode->builtin_code()));
+ } else {
+ operators_table->push_back(opcode->custom_code()->c_str());
+ }
+ }
+}
+} // namespace details
+
+void ImportTensors(const ::tflite::Model& input_model, Model* model) {
+ auto tensors = (*input_model.subgraphs())[0]->tensors();
+ auto* buffers = input_model.buffers();
+ // auto tensors = input_model.tensors();
+ if (!tensors) return;
+ for (const auto* input_tensor : *tensors) {
+ Array& array = model->GetOrCreateArray(input_tensor->name()->c_str());
+ array.data_type = DataType::Deserialize(input_tensor->type());
+ int buffer_index = input_tensor->buffer();
+ auto* buffer = buffers->Get(buffer_index);
+ DataBuffer::Deserialize(*input_tensor, *buffer, &array);
+
+ auto shape = input_tensor->shape();
+ if (shape) {
+ for (int i = 0; i < shape->Length(); ++i) {
+ auto d = shape->Get(i);
+ array.mutable_shape()->mutable_dims()->push_back(d);
+ }
+ }
+
+ auto quantization = input_tensor->quantization();
+ if (quantization) {
+ // Note that tf.mini only supports a single quantization parameters for
+ // the whole array.
+ if (quantization->min() && quantization->max()) {
+ CHECK_EQ(1, quantization->min()->Length());
+ CHECK_EQ(1, quantization->max()->Length());
+ MinMax& minmax = array.GetOrCreateMinMax();
+ minmax.min = quantization->min()->Get(0);
+ minmax.max = quantization->max()->Get(0);
+ }
+ if (quantization->scale() && quantization->zero_point()) {
+ CHECK_EQ(1, quantization->scale()->Length());
+ CHECK_EQ(1, quantization->zero_point()->Length());
+ QuantizationParams& q = array.GetOrCreateQuantizationParams();
+ q.scale = quantization->scale()->Get(0);
+ q.zero_point = quantization->zero_point()->Get(0);
+ }
+ }
+ }
+}
+
+void ImportOperators(
+ const ::tflite::Model& input_model,
+ const std::map<string, std::unique_ptr<BaseOperator>>& ops_by_name,
+ const details::TensorsTable& tensors_table,
+ const details::OperatorsTable& operators_table, Model* model) {
+ // TODO(aselle): add support for multiple subgraphs.
+ auto ops = (*input_model.subgraphs())[0]->operators();
+
+ if (!ops) return;
+ for (const auto* input_op : *ops) {
+ int index = input_op->opcode_index();
+ if (index < 0 || index > operators_table.size()) {
+ LOG(FATAL) << "Index " << index << " must be between zero and "
+ << operators_table.size();
+ }
+ string opname = operators_table.at(index);
+ if (ops_by_name.count(opname) == 0) {
+ LOG(FATAL) << "Op '" << opname << "' not supported";
+ }
+
+ auto new_op = ops_by_name.at(opname)->Deserialize(
+ input_op->builtin_options(), input_op->custom_options());
+ model->operators.emplace_back(new_op.release());
+ auto* op = model->operators.back().get();
+
+ auto inputs = input_op->inputs();
+ for (int i = 0; i < inputs->Length(); i++) {
+ auto input_index = inputs->Get(i);
+ const string& input_name = tensors_table.at(input_index);
+ op->inputs.push_back(input_name);
+ }
+ auto outputs = input_op->outputs();
+ for (int i = 0; i < outputs->Length(); i++) {
+ auto output_index = outputs->Get(i);
+ const string& output_name = tensors_table.at(output_index);
+ op->outputs.push_back(output_name);
+ }
+ }
+}
+
+void ImportIOTensors(const ::tflite::Model& input_model,
+ const details::TensorsTable& tensors_table, Model* model) {
+ auto inputs = (*input_model.subgraphs())[0]->inputs();
+ if (inputs) {
+ for (int input : *inputs) {
+ const string& input_name = tensors_table.at(input);
+ model->flags.add_input_arrays()->set_name(input_name);
+ }
+ }
+
+ auto outputs = (*input_model.subgraphs())[0]->outputs();
+ if (outputs) {
+ for (int output : *outputs) {
+ const string& output_name = tensors_table.at(output);
+ model->flags.add_output_arrays(output_name);
+ }
+ }
+}
+
+std::unique_ptr<Model> Import(const ModelFlags& model_flags,
+ const string& input_file_contents) {
+ const ::tflite::Model* input_model =
+ ::tflite::GetModel(input_file_contents.data());
+
+ // Full list of all known operators.
+ const auto ops_by_name = BuildOperatorByNameMap();
+
+ if (input_model->subgraphs()->size() != 1) {
+ LOG(FATAL) << "# of subgraphs in tflite should be exactly 1 for now.";
+ }
+ std::unique_ptr<Model> model;
+ model.reset(new Model);
+
+ details::TensorsTable tensors_table;
+ details::LoadTensorsTable(*input_model, &tensors_table);
+
+ details::OperatorsTable operators_table;
+ details::LoadOperatorsTable(*input_model, &operators_table);
+
+ ImportTensors(*input_model, model.get());
+ ImportOperators(*input_model, ops_by_name, tensors_table, operators_table,
+ model.get());
+ ImportIOTensors(*input_model, tensors_table, model.get());
+
+ return model;
+}
+
+} // namespace tflite
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tflite/import.h b/tensorflow/contrib/lite/toco/tflite/import.h
new file mode 100644
index 0000000000..3c27a2843c
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/import.h
@@ -0,0 +1,49 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_
+
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+
+namespace toco {
+
+namespace tflite {
+
+// Parse the given string as TF Lite flatbuffer and return a new tf.mini model.
+std::unique_ptr<Model> Import(const ModelFlags &model_flags,
+ const string &input_file_contents);
+
+namespace details {
+
+// The names of all tensors found in a TF Lite model.
+using TensorsTable = std::vector<string>;
+
+// The names of all operators found in TF Lite model. If the operator is
+// builtin, the string representation of the corresponding enum value is used
+// as name.
+using OperatorsTable = std::vector<string>;
+
+void LoadTensorsTable(const ::tflite::Model &input_model,
+ TensorsTable *tensors_table);
+void LoadOperatorsTable(const ::tflite::Model &input_model,
+ OperatorsTable *operators_table);
+
+} // namespace details
+} // namespace tflite
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_
diff --git a/tensorflow/contrib/lite/toco/tflite/import_test.cc b/tensorflow/contrib/lite/toco/tflite/import_test.cc
new file mode 100644
index 0000000000..309fa6d7f6
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/import_test.cc
@@ -0,0 +1,141 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tflite/import.h"
+
+#include "flatbuffers/flexbuffers.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/version.h"
+
+namespace toco {
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+
+class ImportTest : public ::testing::Test {
+ protected:
+ template <typename T>
+ flatbuffers::Offset<flatbuffers::Vector<unsigned char>> CreateDataVector(
+ const std::vector<T>& data) {
+ return builder_.CreateVector(reinterpret_cast<const uint8_t*>(data.data()),
+ sizeof(T) * data.size());
+ }
+ // This is a very simplistic model. We are not interested in testing all the
+ // details here, since tf.mini's testing framework will be exercising all the
+ // conversions multiple times, and the conversion of operators is tested by
+ // separate unittests.
+ void BuildTestModel() {
+ // The tensors
+ auto q = ::tflite::CreateQuantizationParameters(
+ builder_,
+ /*min=*/builder_.CreateVector<float>({0.1f}),
+ /*max=*/builder_.CreateVector<float>({0.2f}),
+ /*scale=*/builder_.CreateVector<float>({0.3f}),
+ /*zero_point=*/builder_.CreateVector<int64_t>({100ll}));
+ auto buf0 = ::tflite::CreateBuffer(builder_, CreateDataVector<float>({}));
+ auto buf1 =
+ ::tflite::CreateBuffer(builder_, CreateDataVector<float>({1.0f, 2.0f}));
+ auto buf2 =
+ ::tflite::CreateBuffer(builder_, CreateDataVector<float>({3.0f}));
+ auto buffers = builder_.CreateVector(
+ std::vector<flatbuffers::Offset<::tflite::Buffer>>({buf0, buf1, buf2}));
+ auto t1 = ::tflite::CreateTensor(builder_,
+ builder_.CreateVector<int>({1, 2, 3, 4}),
+ ::tflite::TensorType_FLOAT32, 1,
+ builder_.CreateString("tensor_one"), q);
+ auto t2 =
+ ::tflite::CreateTensor(builder_, builder_.CreateVector<int>({2, 1}),
+ ::tflite::TensorType_FLOAT32, 2,
+ builder_.CreateString("tensor_two"), q);
+ auto tensors = builder_.CreateVector(
+ std::vector<flatbuffers::Offset<::tflite::Tensor>>({t1, t2}));
+
+ // The operator codes.
+ auto c1 =
+ ::tflite::CreateOperatorCode(builder_, ::tflite::BuiltinOperator_CUSTOM,
+ builder_.CreateString("custom_op_one"));
+ auto c2 = ::tflite::CreateOperatorCode(
+ builder_, ::tflite::BuiltinOperator_CONV_2D, 0);
+ auto opcodes = builder_.CreateVector(
+ std::vector<flatbuffers::Offset<::tflite::OperatorCode>>({c1, c2}));
+
+ auto subgraph = ::tflite::CreateSubGraph(builder_, tensors, 0, 0, 0);
+ std::vector<flatbuffers::Offset<::tflite::SubGraph>> subgraph_vector(
+ {subgraph});
+ auto subgraphs = builder_.CreateVector(subgraph_vector);
+ auto s = builder_.CreateString("");
+ builder_.Finish(::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
+ opcodes, subgraphs, s, buffers));
+
+ input_model_ = ::tflite::GetModel(builder_.GetBufferPointer());
+ }
+ string InputModelAsString() {
+ return string(reinterpret_cast<char*>(builder_.GetBufferPointer()),
+ builder_.GetSize());
+ }
+ flatbuffers::FlatBufferBuilder builder_;
+ // const uint8_t* buffer_ = nullptr;
+ const ::tflite::Model* input_model_ = nullptr;
+};
+
+TEST_F(ImportTest, LoadTensorsTable) {
+ BuildTestModel();
+
+ details::TensorsTable tensors;
+ details::LoadTensorsTable(*input_model_, &tensors);
+ EXPECT_THAT(tensors, ElementsAre("tensor_one", "tensor_two"));
+}
+
+TEST_F(ImportTest, LoadOperatorsTable) {
+ BuildTestModel();
+
+ details::OperatorsTable operators;
+ details::LoadOperatorsTable(*input_model_, &operators);
+ EXPECT_THAT(operators, ElementsAre("custom_op_one", "CONV_2D"));
+}
+
+TEST_F(ImportTest, Tensors) {
+ BuildTestModel();
+
+ auto model = Import(ModelFlags(), InputModelAsString());
+
+ ASSERT_GT(model->arrays.count("tensor_one"), 0);
+ Array& a1 = model->GetArray("tensor_one");
+ EXPECT_EQ(ArrayDataType::kFloat, a1.data_type);
+ EXPECT_THAT(a1.GetBuffer<ArrayDataType::kFloat>().data,
+ ElementsAre(1.0f, 2.0f));
+ ASSERT_TRUE(a1.has_shape());
+ EXPECT_THAT(a1.shape().dims(), ElementsAre(1, 2, 3, 4));
+
+ const auto& mm = a1.minmax;
+ ASSERT_TRUE(mm.get());
+ EXPECT_FLOAT_EQ(0.1, mm->min);
+ EXPECT_FLOAT_EQ(0.2, mm->max);
+
+ const auto& q = a1.quantization_params;
+ ASSERT_TRUE(q.get());
+ EXPECT_FLOAT_EQ(0.3, q->scale);
+ EXPECT_EQ(100, q->zero_point);
+}
+
+// TODO(ahentz): still need tests for Operators and IOTensors.
+
+} // namespace
+} // namespace tflite
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
new file mode 100644
index 0000000000..8a33500ddc
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -0,0 +1,627 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tflite/operator.h"
+
+#include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h"
+#include "tensorflow/contrib/lite/toco/tflite/custom_operator.h"
+#include "tensorflow/contrib/lite/toco/tflite/simple_operator.h"
+#include "tensorflow/contrib/lite/toco/tflite/types.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+
+namespace toco {
+
+namespace tflite {
+
+class AveragePool
+ : public BuiltinOperator<AveragePoolOperator, ::tflite::Pool2DOptions,
+ ::tflite::BuiltinOptions_Pool2DOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto padding = Padding::Serialize(op.padding.type);
+ auto activation_function =
+ ActivationFunction::Serialize(op.fused_activation_function);
+ return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
+ op.stride_height, op.kwidth,
+ op.kheight, activation_function);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->padding.type = Padding::Deserialize(options.padding());
+ op->stride_width = options.stride_w();
+ op->stride_height = options.stride_h();
+ op->kwidth = options.filter_width();
+ op->kheight = options.filter_height();
+ op->fused_activation_function =
+ ActivationFunction::Deserialize(options.fused_activation_function());
+ }
+};
+
+class Convolution
+ : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions,
+ ::tflite::BuiltinOptions_Conv2DOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto padding = Padding::Serialize(op.padding.type);
+ auto activation_function =
+ ActivationFunction::Serialize(op.fused_activation_function);
+ return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width,
+ op.stride_height, activation_function);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->padding.type = Padding::Deserialize(options.padding());
+ op->stride_width = options.stride_w();
+ op->stride_height = options.stride_h();
+ op->fused_activation_function =
+ ActivationFunction::Deserialize(options.fused_activation_function());
+ }
+};
+
+class DepthwiseConvolution
+ : public BuiltinOperator<DepthwiseConvOperator,
+ ::tflite::DepthwiseConv2DOptions,
+ ::tflite::BuiltinOptions_DepthwiseConv2DOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto padding = Padding::Serialize(op.padding.type);
+ auto activation_function =
+ ActivationFunction::Serialize(op.fused_activation_function);
+ return ::tflite::CreateDepthwiseConv2DOptions(
+ *builder, padding, op.stride_width, op.stride_height,
+ op.depth_multiplier, activation_function);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->padding.type = Padding::Deserialize(options.padding());
+ op->stride_width = options.stride_w();
+ op->stride_height = options.stride_h();
+ op->depth_multiplier = options.depth_multiplier();
+ op->fused_activation_function =
+ ActivationFunction::Deserialize(options.fused_activation_function());
+ }
+};
+
+class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
+ ::tflite::BuiltinOptions_AddOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto activation_function =
+ ActivationFunction::Serialize(op.fused_activation_function);
+ return ::tflite::CreateAddOptions(*builder, activation_function);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->fused_activation_function =
+ ActivationFunction::Deserialize(options.fused_activation_function());
+ }
+};
+
+class Cast : public CustomOperator<CastOperator> {
+ public:
+ using CustomOperator::CustomOperator;
+ void WriteOptions(const TocoOperator& op,
+ flexbuffers::Builder* fbb) const override {
+ fbb->Int("src_data_type", DataType::Serialize(op.src_data_type));
+ fbb->Int("dst_data_type", DataType::Serialize(op.dst_data_type));
+ }
+ void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
+ op->src_data_type = DataType::Deserialize(m["src_data_type"].AsInt64());
+ op->dst_data_type = DataType::Deserialize(m["dst_data_type"].AsInt64());
+ }
+};
+
+class Concatenation
+ : public BuiltinOperator<ConcatenationOperator,
+ ::tflite::ConcatenationOptions,
+ ::tflite::BuiltinOptions_ConcatenationOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateConcatenationOptions(*builder, op.concat_dim);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->concat_dim = options.axis();
+ }
+};
+
+class DepthToSpace : public CustomOperator<DepthToSpaceOperator> {
+ public:
+ using CustomOperator::CustomOperator;
+ void WriteOptions(const TocoOperator& op,
+ flexbuffers::Builder* fbb) const override {
+ fbb->Int("block_size", op.block_size);
+ }
+ void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
+ op->block_size = m["block_size"].AsInt64();
+ }
+};
+
+class FakeQuant : public CustomOperator<FakeQuantOperator> {
+ public:
+ using CustomOperator::CustomOperator;
+ void WriteOptions(const TocoOperator& op,
+ flexbuffers::Builder* fbb) const override {
+ fbb->Float("min", op.minmax->min);
+ fbb->Float("max", op.minmax->max);
+ }
+ void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
+ auto* minmax = new MinMax;
+ minmax->min = m["min"].AsFloat();
+ minmax->max = m["max"].AsFloat();
+ op->minmax.reset(minmax);
+ }
+};
+
+class FullyConnected
+ : public BuiltinOperator<FullyConnectedOperator,
+ ::tflite::FullyConnectedOptions,
+ ::tflite::BuiltinOptions_FullyConnectedOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto activation_function =
+ ActivationFunction::Serialize(op.fused_activation_function);
+ return ::tflite::CreateFullyConnectedOptions(*builder, activation_function);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->fused_activation_function =
+ ActivationFunction::Deserialize(options.fused_activation_function());
+ }
+};
+
+class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
+ ::tflite::BuiltinOptions_SVDFOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto activation_function =
+ ActivationFunction::Serialize(op.fused_activation_function);
+ return ::tflite::CreateSVDFOptions(*builder, op.rank, activation_function);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->fused_activation_function =
+ ActivationFunction::Deserialize(options.fused_activation_function());
+ op->rank = options.rank();
+ }
+};
+
+class L2Normalization
+ : public BuiltinOperator<L2NormalizationOperator, ::tflite::L2NormOptions,
+ ::tflite::BuiltinOptions_L2NormOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto activation_function =
+ ActivationFunction::Serialize(op.fused_activation_function);
+ return ::tflite::CreateL2NormOptions(*builder, activation_function);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->fused_activation_function =
+ ActivationFunction::Deserialize(options.fused_activation_function());
+ }
+};
+
+class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
+ ::tflite::BuiltinOptions_Pool2DOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto padding = Padding::Serialize(op.padding.type);
+ auto activation_function =
+ ActivationFunction::Serialize(op.fused_activation_function);
+ return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
+ op.stride_height, op.kwidth,
+ op.kheight, activation_function);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->padding.type = Padding::Deserialize(options.padding());
+ op->stride_width = options.stride_w();
+ op->stride_height = options.stride_h();
+ op->kwidth = options.filter_width();
+ op->kheight = options.filter_height();
+ op->fused_activation_function =
+ ActivationFunction::Deserialize(options.fused_activation_function());
+ }
+};
+
+class LocalResponseNormalization
+ : public BuiltinOperator<
+ LocalResponseNormalizationOperator,
+ ::tflite::LocalResponseNormalizationOptions,
+ ::tflite::BuiltinOptions_LocalResponseNormalizationOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateLocalResponseNormalizationOptions(
+ *builder, op.range, op.bias, op.alpha, op.beta);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->range = options.radius();
+ op->bias = options.bias();
+ op->alpha = options.alpha();
+ op->beta = options.beta();
+ }
+};
+
+class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
+ ::tflite::BuiltinOptions_Pool2DOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto padding = Padding::Serialize(op.padding.type);
+ auto activation_function =
+ ActivationFunction::Serialize(op.fused_activation_function);
+ return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
+ op.stride_height, op.kwidth,
+ op.kheight, activation_function);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->padding.type = Padding::Deserialize(options.padding());
+ op->stride_width = options.stride_w();
+ op->stride_height = options.stride_h();
+ op->kwidth = options.filter_width();
+ op->kheight = options.filter_height();
+ op->fused_activation_function =
+ ActivationFunction::Deserialize(options.fused_activation_function());
+ }
+};
+
+class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
+ ::tflite::BuiltinOptions_MulOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto activation_function =
+ ActivationFunction::Serialize(op.fused_activation_function);
+ return ::tflite::CreateMulOptions(*builder, activation_function);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->fused_activation_function =
+ ActivationFunction::Deserialize(options.fused_activation_function());
+ }
+};
+
+class Reshape
+ : public BuiltinOperator<TensorFlowReshapeOperator,
+ ::tflite::ReshapeOptions,
+ ::tflite::BuiltinOptions_ReshapeOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateReshapeOptions(*builder,
+ builder->CreateVector(op.shape));
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->shape.insert(op->shape.end(), options.new_shape()->begin(),
+ options.new_shape()->end());
+ }
+};
+
+class Softmax
+ : public BuiltinOperator<SoftmaxOperator, ::tflite::SoftmaxOptions,
+ ::tflite::BuiltinOptions_SoftmaxOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateSoftmaxOptions(*builder, op.beta);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->beta = options.beta();
+ }
+};
+
+class SpaceToDepth
+ : public BuiltinOperator<SpaceToDepthOperator,
+ ::tflite::SpaceToDepthOptions,
+ ::tflite::BuiltinOptions_SpaceToDepthOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateSpaceToDepthOptions(*builder, op.block_size);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->block_size = options.block_size();
+ }
+};
+
+class Split : public CustomOperator<TensorFlowSplitOperator> {
+ public:
+ using CustomOperator::CustomOperator;
+ void WriteOptions(const TocoOperator& op,
+ flexbuffers::Builder* fbb) const override {
+ fbb->Int("num_split", op.num_split);
+ }
+ void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
+ op->num_split = m["num_split"].AsInt64();
+ }
+};
+
+class TensorFlowUnsupported : public BaseOperator {
+ public:
+ using BaseOperator::BaseOperator;
+
+ Options Serialize(const Operator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto fbb =
+ WriteOptions(static_cast<const TensorFlowUnsupportedOperator&>(op));
+ if (fbb) {
+ return Options::Custom(builder->CreateVector(fbb->GetBuffer()));
+ } else {
+ return Options::Custom(0);
+ }
+ }
+
+ std::unique_ptr<Operator> Deserialize(
+ const BuiltinOptions* builtin_options,
+ const CustomOptions* custom_options) const override {
+ auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
+ if (custom_options) {
+ auto flexbuffer_map =
+ flexbuffers::GetRoot(custom_options->data(), custom_options->size())
+ .AsMap();
+ ReadOptions(flexbuffer_map, op.get());
+ }
+ return std::unique_ptr<Operator>(op.release());
+ }
+
+ std::unique_ptr<flexbuffers::Builder> WriteOptions(
+ const TensorFlowUnsupportedOperator& op) const {
+ auto fbb = absl::make_unique<flexbuffers::Builder>();
+
+ ::tensorflow::NodeDef node_def;
+ if (!node_def.ParseFromString(op.tensorflow_node_def)) {
+ LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
+ return std::unique_ptr<flexbuffers::Builder>();
+ }
+
+ bool has_valid_attr = false;
+ size_t map_start = fbb->StartMap();
+ for (const auto& pair : node_def.attr()) {
+ const char* key = pair.first.c_str();
+ const auto& attr = pair.second;
+ switch (attr.value_case()) {
+ case ::tensorflow::AttrValue::kS:
+ fbb->String(key, attr.s());
+ has_valid_attr = true;
+ break;
+ case ::tensorflow::AttrValue::kI:
+ fbb->Int(key, attr.i());
+ has_valid_attr = true;
+ break;
+ case ::tensorflow::AttrValue::kF:
+ fbb->Float(key, attr.f());
+ has_valid_attr = true;
+ break;
+ case ::tensorflow::AttrValue::kB:
+ fbb->Bool(key, attr.b());
+ has_valid_attr = true;
+ break;
+ default:
+ LOG(WARNING) << "Ignoring unsupported attribute type with key '"
+ << key << "'";
+ break;
+ }
+ }
+ if (!has_valid_attr) {
+ return std::unique_ptr<flexbuffers::Builder>();
+ }
+ fbb->EndMap(map_start);
+ fbb->Finish();
+ return std::unique_ptr<flexbuffers::Builder>(fbb.release());
+ }
+
+ void ReadOptions(const flexbuffers::Map& m,
+ TensorFlowUnsupportedOperator* op) const {
+ ::tensorflow::NodeDef node_def;
+ auto attr = node_def.mutable_attr();
+
+ const auto& keys = m.Keys();
+ for (size_t i = 0; i < keys.size(); ++i) {
+ const auto key = keys[i].AsKey();
+ const auto& value = m[key];
+ switch (value.GetType()) {
+ case flexbuffers::TYPE_STRING:
+ (*attr)[key].set_s(value.AsString().c_str());
+ break;
+ case flexbuffers::TYPE_INT:
+ (*attr)[key].set_i(value.AsInt64());
+ break;
+ case flexbuffers::TYPE_FLOAT:
+ (*attr)[key].set_f(value.AsFloat());
+ break;
+ case flexbuffers::TYPE_BOOL:
+ (*attr)[key].set_b(value.AsBool());
+ break;
+ default:
+ LOG(WARNING) << "Ignoring unsupported attribute type with key '"
+ << key << "'";
+ break;
+ }
+ }
+ node_def.SerializeToString(&op->tensorflow_node_def);
+ }
+};
+
+namespace {
+// Build a vector containing all the known operators.
+std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
+ std::vector<std::unique_ptr<BaseOperator>> ops;
+
+ // Builtin Operators.
+ ops.emplace_back(new Add(::tflite::BuiltinOperator_ADD, OperatorType::kAdd));
+ ops.emplace_back(new AveragePool(::tflite::BuiltinOperator_AVERAGE_POOL_2D,
+ OperatorType::kAveragePool));
+ ops.emplace_back(new Concatenation(::tflite::BuiltinOperator_CONCATENATION,
+ OperatorType::kConcatenation));
+ ops.emplace_back(
+ new Convolution(::tflite::BuiltinOperator_CONV_2D, OperatorType::kConv));
+ ops.emplace_back(
+ new DepthwiseConvolution(::tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
+ OperatorType::kDepthwiseConv));
+ ops.emplace_back(new FullyConnected(::tflite::BuiltinOperator_FULLY_CONNECTED,
+ OperatorType::kFullyConnected));
+ ops.emplace_back(
+ new L2Normalization(::tflite::BuiltinOperator_L2_NORMALIZATION,
+ OperatorType::kL2Normalization));
+ ops.emplace_back(
+ new L2Pool(::tflite::BuiltinOperator_L2_POOL_2D, OperatorType::kL2Pool));
+ ops.emplace_back(new LocalResponseNormalization(
+ ::tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
+ OperatorType::kLocalResponseNormalization));
+ ops.emplace_back(new MaxPool(::tflite::BuiltinOperator_MAX_POOL_2D,
+ OperatorType::kMaxPool));
+ ops.emplace_back(new Mul(::tflite::BuiltinOperator_MUL, OperatorType::kMul));
+ ops.emplace_back(new Reshape(::tflite::BuiltinOperator_RESHAPE,
+ OperatorType::kTensorFlowReshape));
+ ops.emplace_back(
+ new Softmax(::tflite::BuiltinOperator_SOFTMAX, OperatorType::kSoftmax));
+ ops.emplace_back(new SpaceToDepth(::tflite::BuiltinOperator_SPACE_TO_DEPTH,
+ OperatorType::kSpaceToDepth));
+ ops.emplace_back(
+ new Svdf(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf));
+
+ // Custom Operators.
+ ops.emplace_back(new Cast("CAST", OperatorType::kCast));
+ ops.emplace_back(
+ new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
+ ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant));
+ ops.emplace_back(new Split("SPLIT", OperatorType::kTensorFlowSplit));
+ ops.emplace_back(new TensorFlowUnsupported(
+ "TENSORFLOW_UNSUPPORTED", OperatorType::kTensorFlowUnsupported));
+
+ // There operators are supported by Toco, but not by TF Lite, and has no
+ // attributes.
+ ops.emplace_back(new SimpleOperator<TensorFlowRsqrtOperator>(
+ "RSQRT", OperatorType::kTensorFlowRsqrt));
+ ops.emplace_back(
+ new SimpleOperator<TensorFlowRsqrtOperator>("DIV", OperatorType::kDiv));
+
+ // Simple Operators.
+ ops.emplace_back(new SimpleOperator<DequantizeOperator>(
+ "DEQUANTIZE", OperatorType::kDequantize));
+ ops.emplace_back(
+ new SimpleOperator<FloorOperator>("FLOOR", OperatorType::kFloor));
+ ops.emplace_back(
+ new SimpleOperator<GatherOperator>("GATHER", OperatorType::kGather));
+ ops.emplace_back(
+ new SimpleOperator<ReluOperator>("RELU", OperatorType::kRelu));
+ ops.emplace_back(
+ new SimpleOperator<Relu1Operator>("RELU1", OperatorType::kRelu1));
+ ops.emplace_back(
+ new SimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6));
+ ops.emplace_back(new SimpleOperator<ResizeBilinearOperator>(
+ "RESIZE_BILINEAR", OperatorType::kResizeBilinear));
+ ops.emplace_back(new SimpleOperator<LogisticOperator>(
+ "LOGISTIC", OperatorType::kLogistic));
+ ops.emplace_back(
+ new SimpleOperator<TanhOperator>("TANH", OperatorType::kTanh));
+
+ return ops;
+}
+} // namespace
+
+std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap() {
+ std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
+
+ std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList();
+ for (auto& op : ops) {
+ result[op->type()] = std::move(op);
+ }
+
+ return result;
+}
+
+std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap() {
+ std::map<string, std::unique_ptr<BaseOperator>> result;
+
+ std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList();
+ for (auto& op : ops) {
+ result[op->name()] = std::move(op);
+ }
+
+ return result;
+}
+
+} // namespace tflite
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h
new file mode 100644
index 0000000000..37df302d46
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/operator.h
@@ -0,0 +1,89 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_
+
+#include "flatbuffers/flatbuffers.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+
+namespace toco {
+
+namespace tflite {
+
+class BaseOperator;
+
+// Return a map contained all knwo TF Lite Operators, keyed by their names.
+std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap();
+
+// Return a map contained all knwo TF Lite Operators, keyed by the type of
+// their tf.mini counterparts.
+std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap();
+
+// These are the flatbuffer types for custom and builtin options.
+using CustomOptions = flatbuffers::Vector<uint8_t>;
+using BuiltinOptions = void;
+
+// A simple wrapper around the flatbuffer objects used to describe options that
+// configure operators.
+struct Options {
+ // Build custom options.
+ static Options Custom(flatbuffers::Offset<CustomOptions> offset) {
+ return {::tflite::BuiltinOptions_NONE, 0, offset};
+ }
+
+ // Build builtin options of the given type.
+ static Options Builtin(::tflite::BuiltinOptions type,
+ flatbuffers::Offset<BuiltinOptions> offset) {
+ return {type, offset, 0};
+ }
+
+ ::tflite::BuiltinOptions type;
+ flatbuffers::Offset<BuiltinOptions> builtin;
+ flatbuffers::Offset<CustomOptions> custom;
+};
+
+// A BaseOperator encapsulates the relationship between operators in tf.mini
+// and TF lite, and provides methods for converting between those two formats.
+class BaseOperator {
+ public:
+ // Build an operator with the given TF Lite name and tf.mini type.
+ BaseOperator(const string& name, OperatorType type)
+ : name_(name), type_(type) {}
+ virtual ~BaseOperator() = default;
+
+ string name() const { return name_; }
+ OperatorType type() const { return type_; }
+
+ // Given a tf.mini operator, create the corresponding flatbuffer options and
+ // return their offsets.
+ virtual Options Serialize(const Operator& op,
+ flatbuffers::FlatBufferBuilder* builder) const = 0;
+
+ // Read TF Lite options and create the appropriate tf.mini operator.
+ virtual std::unique_ptr<Operator> Deserialize(
+ const BuiltinOptions* builtin_options,
+ const CustomOptions* custom_options) const = 0;
+
+ private:
+ string name_;
+ OperatorType type_;
+};
+
+} // namespace tflite
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
new file mode 100644
index 0000000000..543a9bd06c
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -0,0 +1,370 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tflite/operator.h"
+
+#include "flatbuffers/flexbuffers.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+
+namespace toco {
+
+namespace tflite {
+namespace {
+
+class OperatorTest : public ::testing::Test {
+ protected:
+ // Return the operator for the given name and type.
+ const BaseOperator& GetOperator(const string& name, OperatorType type) {
+ using OpsByName = std::map<string, std::unique_ptr<BaseOperator>>;
+ using OpsByType = std::map<OperatorType, std::unique_ptr<BaseOperator>>;
+
+ static auto* by_name = new OpsByName(BuildOperatorByNameMap());
+ static auto* by_type = new OpsByType(BuildOperatorByTypeMap());
+
+ // Make sure the two maps were consitently built.
+ CHECK(by_name->count(name)) << "No operator for '" << name << "'.";
+ BaseOperator* op1 = by_name->at(name).get();
+ CHECK(op1->type() == type) << "while verifying '" << name << "'.";
+
+ CHECK(by_type->count(type))
+ << "No operator for '" << OperatorTypeName(type) << "'.";
+ BaseOperator* op2 = by_type->at(type).get();
+ CHECK(op2->name() == name)
+ << "while verifying '" << OperatorTypeName(type) << "'.";
+
+ return *op1;
+ }
+
+ // Use the given BaseOperator to serialize the tf.mini operator into a set of
+ // TF Lite options. Proceed to deserialize the options back into a new
+ // tf.mini operator, which is then returned. If `options` is given, it will
+ // be populated with the serialized options.
+ template <typename T>
+ std::unique_ptr<T> SerializeAndDeserialize(const BaseOperator& op,
+ const T& toco_op,
+ Options* options = nullptr) {
+ flatbuffers::FlatBufferBuilder builder;
+ Options input_options = op.Serialize(toco_op, &builder);
+
+ if (options) {
+ *options = input_options;
+ }
+
+ builder.Finish(CreateOperator(builder, 0, 0, 0, input_options.type,
+ input_options.builtin, input_options.custom,
+ ::tflite::CustomOptionsFormat_FLEXBUFFERS));
+ auto* output_options =
+ flatbuffers::GetRoot<::tflite::Operator>(builder.GetBufferPointer());
+ auto new_toco_op = op.Deserialize(output_options->builtin_options(),
+ output_options->custom_options());
+
+ CHECK(dynamic_cast<T*>(new_toco_op.get()))
+ << "Cannot cast " << HelpfulOperatorTypeName(*new_toco_op) << " to "
+ << HelpfulOperatorTypeName(toco_op);
+
+ return std::unique_ptr<T>(dynamic_cast<T*>(new_toco_op.release()));
+ }
+
+ // Verify serialization and deserialization of simple operators (those
+ // that don't have any configuration parameters).
+ template <typename T>
+ void CheckSimpleOperator(const string& name, OperatorType type) {
+ Options options;
+ auto output_toco_op =
+ SerializeAndDeserialize(GetOperator(name, type), T(), &options);
+
+ ASSERT_EQ(0, options.builtin.o);
+ ASSERT_EQ(0, options.custom.o);
+ ASSERT_EQ(::tflite::BuiltinOptions_NONE, options.type);
+
+ ASSERT_NE(nullptr, output_toco_op.get());
+ }
+};
+
+TEST_F(OperatorTest, SimpleOperators) {
+ CheckSimpleOperator<DequantizeOperator>("DEQUANTIZE",
+ OperatorType::kDequantize);
+ CheckSimpleOperator<FloorOperator>("FLOOR", OperatorType::kFloor);
+ CheckSimpleOperator<GatherOperator>("GATHER", OperatorType::kGather);
+ CheckSimpleOperator<ReluOperator>("RELU", OperatorType::kRelu);
+ CheckSimpleOperator<Relu1Operator>("RELU1", OperatorType::kRelu1);
+ CheckSimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6);
+ CheckSimpleOperator<ResizeBilinearOperator>("RESIZE_BILINEAR",
+ OperatorType::kResizeBilinear);
+ CheckSimpleOperator<LogisticOperator>("LOGISTIC", OperatorType::kLogistic);
+ CheckSimpleOperator<TanhOperator>("TANH", OperatorType::kTanh);
+}
+
+TEST_F(OperatorTest, BuiltinAdd) {
+ AddOperator op;
+ op.fused_activation_function = FusedActivationFunctionType::kRelu6;
+ auto output_toco_op =
+ SerializeAndDeserialize(GetOperator("ADD", OperatorType::kAdd), op);
+ EXPECT_EQ(op.fused_activation_function,
+ output_toco_op->fused_activation_function);
+}
+
+TEST_F(OperatorTest, CustomCast) {
+ CastOperator op;
+ op.src_data_type = ArrayDataType::kFloat;
+ op.dst_data_type = ArrayDataType::kUint8;
+ auto output_toco_op =
+ SerializeAndDeserialize(GetOperator("CAST", OperatorType::kCast), op);
+ EXPECT_EQ(op.src_data_type, output_toco_op->src_data_type);
+ EXPECT_EQ(op.dst_data_type, output_toco_op->dst_data_type);
+}
+
+TEST_F(OperatorTest, CustomConcatenation) {
+ ConcatenationOperator op;
+ op.concat_dim = 123;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("CONCATENATION", OperatorType::kConcatenation), op);
+ EXPECT_EQ(op.concat_dim, output_toco_op->concat_dim);
+}
+
+TEST_F(OperatorTest, CustomDepthToSpace) {
+ DepthToSpaceOperator op;
+ op.block_size = 123;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("DEPTH_TO_SPACE", OperatorType::kDepthToSpace), op);
+ EXPECT_EQ(op.block_size, output_toco_op->block_size);
+}
+
+TEST_F(OperatorTest, CustomFakeQuant) {
+ FakeQuantOperator op;
+ auto* minmax = new MinMax;
+ minmax->min = -10;
+ minmax->max = 200;
+ op.minmax.reset(minmax);
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("FAKE_QUANT", OperatorType::kFakeQuant), op);
+ EXPECT_EQ(op.minmax->min, output_toco_op->minmax->min);
+ EXPECT_EQ(op.minmax->max, output_toco_op->minmax->max);
+}
+
+TEST_F(OperatorTest, CustomFullyConnected) {
+ FullyConnectedOperator op;
+ op.fused_activation_function = FusedActivationFunctionType::kRelu6;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("FULLY_CONNECTED", OperatorType::kFullyConnected), op);
+ EXPECT_EQ(op.fused_activation_function,
+ output_toco_op->fused_activation_function);
+}
+
+TEST_F(OperatorTest, BuiltinL2Pool) {
+ L2PoolOperator op;
+ op.stride_width = 123;
+ op.stride_height = 124;
+ op.padding.type = PaddingType::kValid;
+ op.kwidth = 480;
+ op.kheight = 1080;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("L2_POOL_2D", OperatorType::kL2Pool), op);
+ EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
+ EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
+ EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
+ EXPECT_EQ(op.kwidth, output_toco_op->kwidth);
+ EXPECT_EQ(op.kheight, output_toco_op->kheight);
+}
+
+TEST_F(OperatorTest, BuiltinLocalResponseNormalization) {
+ LocalResponseNormalizationOperator op;
+ op.range = 123;
+ op.bias = 1.23;
+ op.alpha = 12.3;
+ op.beta = .123;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("LOCAL_RESPONSE_NORMALIZATION",
+ OperatorType::kLocalResponseNormalization),
+ op);
+ EXPECT_EQ(op.range, output_toco_op->range);
+ EXPECT_EQ(op.bias, output_toco_op->bias);
+ EXPECT_EQ(op.alpha, output_toco_op->alpha);
+ EXPECT_EQ(op.beta, output_toco_op->beta);
+}
+
+TEST_F(OperatorTest, BuiltinMaxPool) {
+ MaxPoolOperator op;
+ op.stride_width = 123;
+ op.stride_height = 124;
+ op.padding.type = PaddingType::kValid;
+ op.kwidth = 480;
+ op.kheight = 1080;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("MAX_POOL_2D", OperatorType::kMaxPool), op);
+ EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
+ EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
+ EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
+ EXPECT_EQ(op.kwidth, output_toco_op->kwidth);
+ EXPECT_EQ(op.kheight, output_toco_op->kheight);
+}
+
+TEST_F(OperatorTest, BuiltinReshape) {
+ TensorFlowReshapeOperator op;
+ op.shape = {1, 2, 4, 5, 8};
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("RESHAPE", OperatorType::kTensorFlowReshape), op);
+ EXPECT_EQ(op.shape, output_toco_op->shape);
+}
+
+TEST_F(OperatorTest, CustomSoftmax) {
+ SoftmaxOperator op;
+ op.beta = 123.1;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("SOFTMAX", OperatorType::kSoftmax), op);
+ EXPECT_EQ(op.beta, output_toco_op->beta);
+}
+
+TEST_F(OperatorTest, BuiltinSpaceToDepth) {
+ SpaceToDepthOperator op;
+ op.block_size = 123;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("SPACE_TO_DEPTH", OperatorType::kSpaceToDepth), op);
+ EXPECT_EQ(op.block_size, output_toco_op->block_size);
+}
+
+TEST_F(OperatorTest, CustomSplit) {
+ TensorFlowSplitOperator op;
+ op.num_split = 123;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("SPLIT", OperatorType::kTensorFlowSplit), op);
+ EXPECT_EQ(op.num_split, output_toco_op->num_split);
+}
+
+TEST_F(OperatorTest, BuiltinAveragePool) {
+ AveragePoolOperator op;
+ op.fused_activation_function = FusedActivationFunctionType::kRelu6;
+ op.stride_width = 123;
+ op.stride_height = 124;
+ op.padding.type = PaddingType::kValid;
+ op.kwidth = 480;
+ op.kheight = 1080;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("AVERAGE_POOL_2D", OperatorType::kAveragePool), op);
+ EXPECT_EQ(op.fused_activation_function,
+ output_toco_op->fused_activation_function);
+ EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
+ EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
+ EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
+ EXPECT_EQ(op.kwidth, output_toco_op->kwidth);
+ EXPECT_EQ(op.kheight, output_toco_op->kheight);
+}
+
+TEST_F(OperatorTest, BuiltinConvolution) {
+ ConvOperator op;
+ op.stride_width = 123;
+ op.stride_height = 124;
+ op.padding.type = PaddingType::kValid;
+ op.fused_activation_function = FusedActivationFunctionType::kRelu6;
+ auto output_toco_op =
+ SerializeAndDeserialize(GetOperator("CONV_2D", OperatorType::kConv), op);
+ EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
+ EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
+ EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
+ EXPECT_EQ(op.fused_activation_function,
+ output_toco_op->fused_activation_function);
+}
+
+TEST_F(OperatorTest, BuiltinDepthwiseConvolution) {
+ DepthwiseConvOperator op;
+ op.stride_width = 123;
+ op.stride_height = 124;
+ op.padding.type = PaddingType::kValid;
+ op.depth_multiplier = 6;
+ op.fused_activation_function = FusedActivationFunctionType::kRelu6;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("DEPTHWISE_CONV_2D", OperatorType::kDepthwiseConv), op);
+ EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
+ EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
+ EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
+ EXPECT_EQ(op.depth_multiplier, output_toco_op->depth_multiplier);
+ EXPECT_EQ(op.fused_activation_function,
+ output_toco_op->fused_activation_function);
+}
+
+TEST_F(OperatorTest, BuiltinL2Norm) {
+ L2NormalizationOperator op;
+ op.fused_activation_function = FusedActivationFunctionType::kRelu6;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("L2_NORMALIZATION", OperatorType::kL2Normalization), op);
+ EXPECT_EQ(op.fused_activation_function,
+ output_toco_op->fused_activation_function);
+}
+
+TEST_F(OperatorTest, BuiltinMul) {
+ MulOperator op;
+ op.fused_activation_function = FusedActivationFunctionType::kRelu6;
+ auto output_toco_op =
+ SerializeAndDeserialize(GetOperator("MUL", OperatorType::kMul), op);
+ EXPECT_EQ(op.fused_activation_function,
+ output_toco_op->fused_activation_function);
+}
+
+TEST_F(OperatorTest, Svdf) {
+ SvdfOperator op;
+ op.fused_activation_function = FusedActivationFunctionType::kRelu;
+ auto output_toco_op =
+ SerializeAndDeserialize(GetOperator("SVDF", OperatorType::kSvdf), op);
+ EXPECT_EQ(op.fused_activation_function,
+ output_toco_op->fused_activation_function);
+}
+
+TEST_F(OperatorTest, TensorFlowUnsupported) {
+ TensorFlowUnsupportedOperator op;
+ op.tensorflow_op = "MyCustomUnsupportedOp";
+
+ ::tensorflow::NodeDef node_def;
+ auto attr = node_def.mutable_attr();
+ (*attr)["float_attr"].set_f(2.0);
+ (*attr)["str_attr"].set_s("Hello World");
+ (*attr)["int_attr"].set_i(17);
+ (*attr)["bool_attr"].set_b(true);
+ node_def.SerializeToString(&op.tensorflow_node_def);
+
+ auto output_toco_op =
+ SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED",
+ OperatorType::kTensorFlowUnsupported),
+ op);
+
+ ::tensorflow::NodeDef output_node_def;
+ output_node_def.ParseFromString(output_toco_op->tensorflow_node_def);
+ const auto& output_attr = output_node_def.attr();
+ EXPECT_EQ(2.0, output_attr.at("float_attr").f());
+ EXPECT_EQ("Hello World", output_attr.at("str_attr").s());
+ EXPECT_EQ(17, output_attr.at("int_attr").i());
+ EXPECT_EQ(true, output_attr.at("bool_attr").b());
+}
+
+TEST_F(OperatorTest, TensorFlowUnsupportedWithoutAttr) {
+ TensorFlowUnsupportedOperator op;
+ op.tensorflow_op = "MyCustomUnsupportedOp";
+ auto output_toco_op =
+ SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED",
+ OperatorType::kTensorFlowUnsupported),
+ op);
+
+ ::tensorflow::NodeDef output_node_def;
+ output_node_def.ParseFromString(output_toco_op->tensorflow_node_def);
+ EXPECT_TRUE(output_node_def.attr().empty());
+}
+
+} // namespace
+} // namespace tflite
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tflite/simple_operator.h b/tensorflow/contrib/lite/toco/tflite/simple_operator.h
new file mode 100644
index 0000000000..992b98baca
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/simple_operator.h
@@ -0,0 +1,50 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_
+
+#include "tensorflow/contrib/lite/toco/tflite/operator.h"
+
+namespace toco {
+
+namespace tflite {
+
+// Simple operators don't have any configuration options and can be trivially
+// serialized and deserialized. Note that most of toco's operators will
+// likely be supported as builtin operators in TF Lite. Simple (and custom)
+// operators are mostly a convenience for the times when tf.mini supports more
+// operators than TF Lite.
+//
+// Template argument T must derive from ::toco::Operator.
+template <typename T>
+class SimpleOperator : public BaseOperator {
+ public:
+ using BaseOperator::BaseOperator;
+ Options Serialize(const Operator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return Options();
+ }
+ std::unique_ptr<Operator> Deserialize(
+ const BuiltinOptions* builtin_options,
+ const CustomOptions* custom_options) const override {
+ return std::unique_ptr<Operator>(new T);
+ }
+};
+
+} // namespace tflite
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_
diff --git a/tensorflow/contrib/lite/toco/tflite/types.cc b/tensorflow/contrib/lite/toco/tflite/types.cc
new file mode 100644
index 0000000000..5b4dbfae24
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/types.cc
@@ -0,0 +1,165 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tflite/types.h"
+
+namespace toco {
+
+namespace tflite {
+
+namespace {
+template <ArrayDataType T>
+DataBuffer::FlatBufferOffset CopyBuffer(
+ const Array& array, flatbuffers::FlatBufferBuilder* builder) {
+ using NativeT = ::toco::DataType<T>;
+ const auto& src_data = array.GetBuffer<T>().data;
+ const uint8_t* dst_data = reinterpret_cast<const uint8_t*>(src_data.data());
+ auto size = src_data.size() * sizeof(NativeT);
+ return builder->CreateVector(dst_data, size);
+}
+
+template <ArrayDataType T>
+void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) {
+ using NativeT = ::toco::DataType<T>;
+ auto* src_buffer = buffer.data();
+ const NativeT* src_data =
+ reinterpret_cast<const NativeT*>(src_buffer->data());
+ int num_items = src_buffer->size() / sizeof(NativeT);
+
+ std::vector<NativeT>* dst_data = &array->GetMutableBuffer<T>().data;
+ for (int i = 0; i < num_items; ++i) {
+ dst_data->push_back(*src_data);
+ ++src_data;
+ }
+}
+} // namespace
+
+::tflite::TensorType DataType::Serialize(ArrayDataType array_data_type) {
+ switch (array_data_type) {
+ case ArrayDataType::kFloat:
+ return ::tflite::TensorType_FLOAT32;
+ case ArrayDataType::kInt32:
+ return ::tflite::TensorType_INT32;
+ case ArrayDataType::kUint8:
+ return ::tflite::TensorType_UINT8;
+ default:
+ // FLOAT32 is filled for unknown data types.
+ // TODO(ycling): Implement type inference in TF Lite interpreter.
+ return ::tflite::TensorType_FLOAT32;
+ }
+}
+
+ArrayDataType DataType::Deserialize(int tensor_type) {
+ switch (::tflite::TensorType(tensor_type)) {
+ case ::tflite::TensorType_FLOAT32:
+ return ArrayDataType::kFloat;
+ case ::tflite::TensorType_INT32:
+ return ArrayDataType::kInt32;
+ case ::tflite::TensorType_UINT8:
+ return ArrayDataType::kUint8;
+ default:
+ LOG(FATAL) << "Unhandled tensor type '" << tensor_type << "'.";
+ }
+}
+
+flatbuffers::Offset<flatbuffers::Vector<uint8_t>> DataBuffer::Serialize(
+ const Array& array, flatbuffers::FlatBufferBuilder* builder) {
+ if (!array.buffer) return 0; // an empty buffer, usually an output.
+
+ switch (array.data_type) {
+ case ArrayDataType::kFloat:
+ return CopyBuffer<ArrayDataType::kFloat>(array, builder);
+ case ArrayDataType::kInt32:
+ return CopyBuffer<ArrayDataType::kInt32>(array, builder);
+ case ArrayDataType::kUint8:
+ return CopyBuffer<ArrayDataType::kUint8>(array, builder);
+ default:
+ LOG(FATAL) << "Unhandled array data type.";
+ }
+}
+
+void DataBuffer::Deserialize(const ::tflite::Tensor& tensor,
+ const ::tflite::Buffer& buffer, Array* array) {
+ if (tensor.buffer() == 0) return; // an empty buffer, usually an output.
+ if (buffer.data() == nullptr) return; // a non-defined buffer.
+
+ switch (tensor.type()) {
+ case ::tflite::TensorType_FLOAT32:
+ return CopyBuffer<ArrayDataType::kFloat>(buffer, array);
+ case ::tflite::TensorType_INT32:
+ return CopyBuffer<ArrayDataType::kInt32>(buffer, array);
+ case ::tflite::TensorType_UINT8:
+ return CopyBuffer<ArrayDataType::kUint8>(buffer, array);
+ default:
+ LOG(FATAL) << "Unhandled tensor type.";
+ }
+}
+
+::tflite::Padding Padding::Serialize(PaddingType padding_type) {
+ switch (padding_type) {
+ case PaddingType::kSame:
+ return ::tflite::Padding_SAME;
+ case PaddingType::kValid:
+ return ::tflite::Padding_VALID;
+ default:
+ LOG(FATAL) << "Unhandled padding type.";
+ }
+}
+
+PaddingType Padding::Deserialize(int padding) {
+ switch (::tflite::Padding(padding)) {
+ case ::tflite::Padding_SAME:
+ return PaddingType::kSame;
+ case ::tflite::Padding_VALID:
+ return PaddingType::kValid;
+ default:
+ LOG(FATAL) << "Unhandled padding.";
+ }
+}
+
+::tflite::ActivationFunctionType ActivationFunction::Serialize(
+ FusedActivationFunctionType faf_type) {
+ switch (faf_type) {
+ case FusedActivationFunctionType::kNone:
+ return ::tflite::ActivationFunctionType_NONE;
+ case FusedActivationFunctionType::kRelu:
+ return ::tflite::ActivationFunctionType_RELU;
+ case FusedActivationFunctionType::kRelu6:
+ return ::tflite::ActivationFunctionType_RELU6;
+ case FusedActivationFunctionType::kRelu1:
+ return ::tflite::ActivationFunctionType_RELU1;
+ default:
+ LOG(FATAL) << "Unhandled fused activation function type.";
+ }
+}
+
+FusedActivationFunctionType ActivationFunction::Deserialize(
+ int activation_function) {
+ switch (::tflite::ActivationFunctionType(activation_function)) {
+ case ::tflite::ActivationFunctionType_NONE:
+ return FusedActivationFunctionType::kNone;
+ case ::tflite::ActivationFunctionType_RELU:
+ return FusedActivationFunctionType::kRelu;
+ case ::tflite::ActivationFunctionType_RELU6:
+ return FusedActivationFunctionType::kRelu6;
+ case ::tflite::ActivationFunctionType_RELU1:
+ return FusedActivationFunctionType::kRelu1;
+ default:
+ LOG(FATAL) << "Unhandled fused activation function type.";
+ }
+}
+
+} // namespace tflite
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tflite/types.h b/tensorflow/contrib/lite/toco/tflite/types.h
new file mode 100644
index 0000000000..f7c5140510
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/types.h
@@ -0,0 +1,58 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_
+
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+
+namespace toco {
+
+namespace tflite {
+
+struct DataType {
+ static ::tflite::TensorType Serialize(ArrayDataType array_data_type);
+ static ArrayDataType Deserialize(int tensor_type);
+};
+
+struct DataBuffer {
+ using FlatBufferOffset = flatbuffers::Offset<flatbuffers::Vector<uint8_t>>;
+
+ // Build the flatbuffer representation of a toco's Array and return the
+ // corresponding offset into the flatbuffer. Note that data from the array
+ // will be copied into the flatbuffer.
+ static FlatBufferOffset Serialize(const Array& array,
+ flatbuffers::FlatBufferBuilder* builder);
+ // Copy data from the given tensor into toco's Array.
+ static void Deserialize(const ::tflite::Tensor& tensor,
+ const ::tflite::Buffer& buffer, Array* array);
+};
+
+struct Padding {
+ static ::tflite::Padding Serialize(PaddingType padding_type);
+ static PaddingType Deserialize(int padding);
+};
+
+struct ActivationFunction {
+ static ::tflite::ActivationFunctionType Serialize(
+ FusedActivationFunctionType faf_type);
+ static FusedActivationFunctionType Deserialize(int activation_function);
+};
+
+} // namespace tflite
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_
diff --git a/tensorflow/contrib/lite/toco/tflite/types_test.cc b/tensorflow/contrib/lite/toco/tflite/types_test.cc
new file mode 100644
index 0000000000..174b78f3e6
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/types_test.cc
@@ -0,0 +1,191 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tflite/types.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace toco {
+
+namespace tflite {
+namespace {
+
+using flatbuffers::FlatBufferBuilder;
+using flatbuffers::Offset;
+using flatbuffers::Vector;
+
+// These are types that exist in TF Mini but don't have a correspondence
+// in TF Lite.
+static const ArrayDataType kUnsupportedTocoTypes[] = {
+ ArrayDataType::kNone, ArrayDataType::kBool, ArrayDataType::kInt64};
+
+// These are TF Lite types for which there is no correspondence in TF Mini.
+static const ::tflite::TensorType kUnsupportedTfLiteTypes[] = {
+ ::tflite::TensorType_FLOAT16};
+
+// A little helper to match flatbuffer offsets.
+MATCHER_P(HasOffset, value, "") { return arg.o == value; }
+
+// Helper function that creates an array, writes it into a flatbuffer, and then
+// reads it back in.
+template <ArrayDataType T>
+Array ToFlatBufferAndBack(std::initializer_list<::toco::DataType<T>> items) {
+ // NOTE: This test does not construct the full buffers list. Since
+ // Deserialize normally takes a buffer, we need to synthesize one and provide
+ // an index that is non-zero so the buffer is not assumed to be emtpy.
+ Array src;
+ src.data_type = T;
+ src.GetMutableBuffer<T>().data = items;
+
+ Array result;
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(CreateTensor(builder, 0, DataType::Serialize(T),
+ /*buffer*/ 1)); // Can't use 0 which means empty.
+ flatbuffers::FlatBufferBuilder buffer_builder;
+ Offset<Vector<uint8_t>> data_buffer =
+ DataBuffer::Serialize(src, &buffer_builder);
+ buffer_builder.Finish(::tflite::CreateBuffer(buffer_builder, data_buffer));
+
+ auto* tensor =
+ flatbuffers::GetRoot<::tflite::Tensor>(builder.GetBufferPointer());
+ auto* buffer =
+ flatbuffers::GetRoot<::tflite::Buffer>(buffer_builder.GetBufferPointer());
+ DataBuffer::Deserialize(*tensor, *buffer, &result);
+ return result;
+}
+
+TEST(DataType, SupportedTypes) {
+ std::vector<std::pair<ArrayDataType, ::tflite::TensorType>> testdata = {
+ {ArrayDataType::kUint8, ::tflite::TensorType_UINT8},
+ {ArrayDataType::kInt32, ::tflite::TensorType_INT32},
+ {ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32}};
+ for (auto x : testdata) {
+ EXPECT_EQ(x.second, DataType::Serialize(x.first));
+ EXPECT_EQ(x.first, DataType::Deserialize(x.second));
+ }
+}
+
+TEST(DataType, UnsupportedTypes) {
+ for (::tflite::TensorType t : kUnsupportedTfLiteTypes) {
+ EXPECT_DEATH(DataType::Deserialize(t), "Unhandled tensor type.");
+ }
+
+ // Unsupported types are all serialized as FLOAT32 currently.
+ for (ArrayDataType t : kUnsupportedTocoTypes) {
+ EXPECT_EQ(::tflite::TensorType_FLOAT32, DataType::Serialize(t));
+ }
+}
+
+TEST(DataBuffer, EmptyBuffers) {
+ flatbuffers::FlatBufferBuilder builder;
+ Array array;
+ EXPECT_THAT(DataBuffer::Serialize(array, &builder), HasOffset(0));
+
+ builder.Finish(::tflite::CreateTensor(builder));
+ auto* tensor =
+ flatbuffers::GetRoot<::tflite::Tensor>(builder.GetBufferPointer());
+ flatbuffers::FlatBufferBuilder buffer_builder;
+ Offset<Vector<uint8_t>> v = buffer_builder.CreateVector<uint8_t>({});
+ buffer_builder.Finish(::tflite::CreateBuffer(buffer_builder, v));
+ auto* buffer =
+ flatbuffers::GetRoot<::tflite::Buffer>(buffer_builder.GetBufferPointer());
+
+ DataBuffer::Deserialize(*tensor, *buffer, &array);
+ EXPECT_EQ(nullptr, array.buffer);
+}
+
+TEST(DataBuffer, UnsupportedTypes) {
+ for (ArrayDataType t : kUnsupportedTocoTypes) {
+ flatbuffers::FlatBufferBuilder builder;
+ Array array;
+ array.data_type = t;
+ array.GetMutableBuffer<ArrayDataType::kFloat>(); // This is OK.
+ EXPECT_DEATH(DataBuffer::Serialize(array, &builder),
+ "Unhandled array data type.");
+ }
+
+ for (::tflite::TensorType t : kUnsupportedTfLiteTypes) {
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(::tflite::CreateTensor(builder, 0, t, /*buffer*/ 1));
+ flatbuffers::FlatBufferBuilder buffer_builder;
+ Offset<Vector<uint8_t>> v = buffer_builder.CreateVector<uint8_t>({1});
+ buffer_builder.Finish(::tflite::CreateBuffer(buffer_builder, v));
+ auto* buffer = flatbuffers::GetRoot<::tflite::Buffer>(
+ buffer_builder.GetBufferPointer());
+ auto* tensor =
+ flatbuffers::GetRoot<::tflite::Tensor>(builder.GetBufferPointer());
+ Array array;
+ EXPECT_DEATH(DataBuffer::Deserialize(*tensor, *buffer, &array),
+ "Unhandled tensor type.");
+ }
+}
+
+TEST(DataBuffer, Float) {
+ Array recovered = ToFlatBufferAndBack<ArrayDataType::kFloat>({1.0f, 2.0f});
+ EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kFloat>().data,
+ ::testing::ElementsAre(1.0f, 2.0f));
+}
+
+TEST(DataBuffer, Uint8) {
+ Array recovered = ToFlatBufferAndBack<ArrayDataType::kUint8>({127, 244});
+ EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kUint8>().data,
+ ::testing::ElementsAre(127, 244));
+}
+
+TEST(DataBuffer, Int32) {
+ Array recovered = ToFlatBufferAndBack<ArrayDataType::kInt32>({1, 1 << 30});
+ EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kInt32>().data,
+ ::testing::ElementsAre(1, 1 << 30));
+}
+
+TEST(Padding, All) {
+ EXPECT_EQ(::tflite::Padding_SAME, Padding::Serialize(PaddingType::kSame));
+ EXPECT_EQ(PaddingType::kSame, Padding::Deserialize(::tflite::Padding_SAME));
+
+ EXPECT_EQ(::tflite::Padding_VALID, Padding::Serialize(PaddingType::kValid));
+ EXPECT_EQ(PaddingType::kValid, Padding::Deserialize(::tflite::Padding_VALID));
+
+ EXPECT_DEATH(Padding::Serialize(static_cast<PaddingType>(10000)),
+ "Unhandled padding type.");
+ EXPECT_DEATH(Padding::Deserialize(10000), "Unhandled padding.");
+}
+
+TEST(ActivationFunction, All) {
+ std::vector<
+ std::pair<FusedActivationFunctionType, ::tflite::ActivationFunctionType>>
+ testdata = {{FusedActivationFunctionType::kNone,
+ ::tflite::ActivationFunctionType_NONE},
+ {FusedActivationFunctionType::kRelu,
+ ::tflite::ActivationFunctionType_RELU},
+ {FusedActivationFunctionType::kRelu6,
+ ::tflite::ActivationFunctionType_RELU6},
+ {FusedActivationFunctionType::kRelu1,
+ ::tflite::ActivationFunctionType_RELU1}};
+ for (auto x : testdata) {
+ EXPECT_EQ(x.second, ActivationFunction::Serialize(x.first));
+ EXPECT_EQ(x.first, ActivationFunction::Deserialize(x.second));
+ }
+
+ EXPECT_DEATH(ActivationFunction::Serialize(
+ static_cast<FusedActivationFunctionType>(10000)),
+ "Unhandled fused activation function type.");
+ EXPECT_DEATH(ActivationFunction::Deserialize(10000),
+ "Unhandled fused activation function type.");
+}
+
+} // namespace
+} // namespace tflite
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/toco.cc b/tensorflow/contrib/lite/toco/toco.cc
new file mode 100644
index 0000000000..f01ec0ec61
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco.cc
@@ -0,0 +1,119 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdio>
+#include <memory>
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h"
+#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/toco_tooling.h"
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+#include "tensorflow/core/platform/logging.h"
+
+#ifndef CHECK_OK
+#define CHECK_OK(val) CHECK_EQ((val).ok(), true)
+#define QCHECK_OK(val) QCHECK_EQ((val).ok(), true)
+#endif
+
+namespace toco {
+namespace {
+
+#define QCHECK_REQUIRE_TOCO_FLAG(arg) \
+ QCHECK(parsed_toco_flags.arg.specified()) << "Missing required flag: " #arg;
+
+void CheckFilePermissions(const ParsedTocoFlags& parsed_toco_flags,
+ const ParsedModelFlags& parsed_model_flags,
+ const TocoFlags& toco_flags) {
+ port::CheckInitGoogleIsDone("InitGoogle is not done yet");
+
+ QCHECK_REQUIRE_TOCO_FLAG(input_file)
+ QCHECK_OK(port::file::Exists(parsed_toco_flags.input_file.value(),
+ port::file::Defaults()))
+ << "Specified input_file does not exist: "
+ << parsed_toco_flags.input_file.value();
+ QCHECK_OK(port::file::Readable(parsed_toco_flags.input_file.value(),
+ port::file::Defaults()))
+ << "Specified input_file exists, but is not readable: "
+ << parsed_toco_flags.input_file.value();
+
+ QCHECK_REQUIRE_TOCO_FLAG(output_file);
+ QCHECK_OK(port::file::Writable(parsed_toco_flags.output_file.value()))
+ << "parsed_toco_flags.input_file.value() output_file is not writable: "
+ << parsed_toco_flags.output_file.value();
+}
+
+void ToolMain(const ParsedTocoFlags& parsed_toco_flags,
+ const ParsedModelFlags& parsed_model_flags) {
+ ModelFlags model_flags;
+ ReadModelFlagsFromCommandLineFlags(parsed_model_flags, &model_flags);
+
+ TocoFlags toco_flags;
+ ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags, &toco_flags);
+
+ CheckFilePermissions(parsed_toco_flags, parsed_model_flags, toco_flags);
+
+ string input_file_contents;
+ CHECK_OK(port::file::GetContents(parsed_toco_flags.input_file.value(),
+ &input_file_contents,
+ port::file::Defaults()));
+ std::unique_ptr<Model> model =
+ Import(toco_flags, model_flags, input_file_contents);
+ Transform(toco_flags, model.get());
+ string output_file_contents;
+ Export(toco_flags, *model, toco_flags.allow_custom_ops(),
+ &output_file_contents);
+ CHECK_OK(port::file::SetContents(parsed_toco_flags.output_file.value(),
+ output_file_contents,
+ port::file::Defaults()));
+}
+
+} // namespace
+} // namespace toco
+
+int main(int argc, char** argv) {
+ toco::string msg;
+ toco::ParsedTocoFlags parsed_toco_flags;
+ toco::ParsedModelFlags parsed_model_flags;
+
+ // If no args were specified, give a help string to be helpful.
+ int* effective_argc = &argc;
+ char** effective_argv = argv;
+ if (argc == 1) {
+ // No arguments, so manufacture help argv.
+ static int dummy_argc = 2;
+ static char* dummy_argv[] = {argv[0], const_cast<char*>("--help")};
+ effective_argc = &dummy_argc;
+ effective_argv = dummy_argv;
+ }
+
+ // Parse toco flags and command flags in sequence, each one strips off args,
+ // giving InitGoogle a chance to handle all remaining arguments.
+ bool toco_success = toco::ParseTocoFlagsFromCommandLineFlags(
+ effective_argc, effective_argv, &msg, &parsed_toco_flags);
+ bool model_success = toco::ParseModelFlagsFromCommandLineFlags(
+ effective_argc, effective_argv, &msg, &parsed_model_flags);
+ if (!toco_success || !model_success || !msg.empty()) {
+ fprintf(stderr, "%s", msg.c_str());
+ fflush(stderr);
+ return 1;
+ }
+ toco::port::InitGoogle(argv[0], effective_argc, &effective_argv, true);
+ toco::ToolMain(parsed_toco_flags, parsed_model_flags);
+}
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
new file mode 100644
index 0000000000..d43c3b4a8e
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
@@ -0,0 +1,206 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+#include <vector>
+
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/strip.h"
+#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace toco {
+
+bool ParseTocoFlagsFromCommandLineFlags(
+ int* argc, char* argv[], string* msg,
+ ParsedTocoFlags* parsed_toco_flags_ptr) {
+ using tensorflow::Flag;
+ ParsedTocoFlags& parsed_flags = *parsed_toco_flags_ptr;
+ std::vector<tensorflow::Flag> flags = {
+ Flag("input_file", parsed_flags.input_file.bind(),
+ parsed_flags.input_file.default_value(),
+ "Input file (model of any supported format). For Protobuf "
+ "formats, both text and binary are supported regardless of file "
+ "extension."),
+ Flag("output_file", parsed_flags.output_file.bind(),
+ parsed_flags.output_file.default_value(),
+ "Output file. "
+ "For Protobuf formats, the binary format will be used."),
+ Flag("input_format", parsed_flags.input_format.bind(),
+ parsed_flags.input_format.default_value(),
+ "Input file format. One of: tensorflow_graphdef, "),
+ Flag("output_format", parsed_flags.output_format.bind(),
+ parsed_flags.output_format.default_value(), "Output file format."),
+ Flag("default_ranges_min", parsed_flags.default_ranges_min.bind(),
+ parsed_flags.default_ranges_min.default_value(),
+ "If defined, will be used as the default value for the min bound "
+ "of min/max ranges used for quantization."),
+ Flag("default_ranges_max", parsed_flags.default_ranges_max.bind(),
+ parsed_flags.default_ranges_max.default_value(),
+ "If defined, will be used as the default value for the max bound "
+ "of min/max ranges used for quantization."),
+ Flag("input_type", parsed_flags.input_type.bind(),
+ parsed_flags.input_type.default_value(),
+ "Data type of the input array in the "
+ "output file. "),
+ Flag("input_types", parsed_flags.input_types.bind(),
+ parsed_flags.input_types.default_value(),
+ "Data types of the input arrays in the "
+ "output file. "
+ "Comma-separated list matching the enumeration order of "
+ "input_arrays."),
+ Flag("inference_type", parsed_flags.inference_type.bind(),
+ parsed_flags.inference_type.default_value(),
+ "Data type, in the output file, of internal and output arrays "
+ "that are FLOAT in the input file. Thus, the value FLOAT means "
+ "keep doing floating-point inference, while the value "
+ "QUANTIZED_UINT8 means replace all internal floating-point "
+ "arithmetic by integer arithmetic producing 8-bit integer "
+ "activations instead of float activations --- which we call "
+ "\'quantized inference\'."),
+ Flag("drop_fake_quant", parsed_flags.drop_fake_quant.bind(),
+ parsed_flags.drop_fake_quant.default_value(),
+ "Ignore and discard FakeQuant nodes. For instance, that can be used "
+ "to "
+ "generate plain float code without fake-quantization from a "
+ "quantized "
+ "graph."),
+ Flag(
+ "reorder_across_fake_quant",
+ parsed_flags.reorder_across_fake_quant.bind(),
+ parsed_flags.reorder_across_fake_quant.default_value(),
+ "Normally, FakeQuant nodes must be strict boundaries for graph "
+ "transformations, in order to ensure that quantized inference has "
+ "the "
+ "exact same arithmetic behavior as quantized training --- which is "
+ "the "
+ "whole point of quantized training and of FakeQuant nodes in the "
+ "first "
+ "place. However, that entails subtle requirements on where exactly "
+ "FakeQuant nodes must be placed in the graph. Some quantized graphs "
+ "have FakeQuant nodes at unexpected locations, that prevent graph "
+ "transformations that are necessary in order to generate inference "
+ "code for these graphs. Such graphs should be fixed, but as a "
+ "temporary work-around, setting this reorder_across_fake_quant flag "
+ "allows toco to perform necessary graph transformaitons on them, "
+ "at the cost of no longer faithfully matching inference and training "
+ "arithmetic."),
+ Flag("allow_custom_ops", parsed_flags.allow_custom_ops.bind(),
+ parsed_flags.allow_custom_ops.default_value(),
+ "If true, allow TOCO to create TF Lite Custom operators for all the"
+ "unsupported Tensorflow ops."),
+ };
+ bool asked_for_help =
+ *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
+ if (asked_for_help) {
+ *msg += tensorflow::Flags::Usage(argv[0], flags);
+ return false;
+ } else {
+ return tensorflow::Flags::Parse(argc, argv, flags);
+ }
+}
+
+void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
+ TocoFlags* toco_flags) {
+ namespace port = toco::port;
+ port::CheckInitGoogleIsDone("InitGoogle is not done yet");
+
+ enum class FlagRequirement { kNone, kMustBeSpecified, kMustNotBeSpecified };
+
+#define ENFORCE_FLAG_REQUIREMENT(name, requirement) \
+ do { \
+ if (requirement == FlagRequirement::kMustBeSpecified) { \
+ QCHECK(parsed_toco_flags.name.specified()) \
+ << "Missing required flag: " << #name; \
+ } \
+ if (requirement == FlagRequirement::kMustNotBeSpecified) { \
+ QCHECK(!parsed_toco_flags.name.specified()) \
+ << "Given other flags, this flag should not have been specified: " \
+ << #name; \
+ } \
+ } while (false)
+
+#define READ_TOCO_FLAG(name, requirement) \
+ ENFORCE_FLAG_REQUIREMENT(name, requirement); \
+ do { \
+ if (parsed_toco_flags.name.specified()) { \
+ toco_flags->set_##name(parsed_toco_flags.name.value()); \
+ } \
+ } while (false)
+
+#define PARSE_TOCO_FLAG(Type, name, requirement) \
+ ENFORCE_FLAG_REQUIREMENT(name, requirement); \
+ do { \
+ if (parsed_toco_flags.name.specified()) { \
+ Type x; \
+ QCHECK(Type##_Parse(parsed_toco_flags.name.value(), &x)) \
+ << "Unrecognized " << #Type << " value " \
+ << parsed_toco_flags.name.value(); \
+ toco_flags->set_##name(x); \
+ } \
+ } while (false)
+
+ PARSE_TOCO_FLAG(FileFormat, input_format, FlagRequirement::kMustBeSpecified);
+ PARSE_TOCO_FLAG(FileFormat, output_format, FlagRequirement::kMustBeSpecified);
+ FlagRequirement tflite_flags_requirement =
+ toco_flags->output_format() == TFLITE
+ ? FlagRequirement::kMustBeSpecified
+ : FlagRequirement::kMustNotBeSpecified;
+ PARSE_TOCO_FLAG(IODataType, inference_type, tflite_flags_requirement);
+ READ_TOCO_FLAG(default_ranges_min, FlagRequirement::kNone);
+ READ_TOCO_FLAG(default_ranges_max, FlagRequirement::kNone);
+ READ_TOCO_FLAG(drop_fake_quant, FlagRequirement::kNone);
+ READ_TOCO_FLAG(reorder_across_fake_quant, FlagRequirement::kNone);
+ READ_TOCO_FLAG(allow_custom_ops, FlagRequirement::kNone);
+
+#undef READ_TOCO_FLAG
+#undef PARSE_TOCO_FLAG
+
+ const bool input_type_specified = parsed_toco_flags.input_type.specified();
+ const bool input_types_specified = parsed_toco_flags.input_types.specified();
+ if (toco_flags->output_format() == TFLITE) {
+ QCHECK(input_type_specified || input_types_specified)
+ << "When output_format=TFLITE, either input_type or input_types needs "
+ "to be specified.";
+ } else {
+ QCHECK(!input_type_specified && !input_types_specified)
+ << "With this output_format, neither input_type nor input_types must "
+ "be specified.";
+ }
+ QCHECK(!(input_type_specified && input_types_specified))
+ << "input_type and input_types are mutually exclusive";
+ if (input_type_specified) {
+ IODataType type;
+ QCHECK(IODataType_Parse(parsed_toco_flags.input_type.value(), &type))
+ << "Unrecognized input_type: " << parsed_toco_flags.input_type.value();
+ toco_flags->add_input_types(type);
+ }
+ if (input_types_specified) {
+ std::vector<string> input_types =
+ absl::StrSplit(parsed_toco_flags.input_types.value(), ',');
+ for (const string& t : input_types) {
+ IODataType type;
+ QCHECK(IODataType_Parse(t, &type))
+ << "Unrecognized input_types value " << t
+ << " in input_types=" << parsed_toco_flags.input_types.value();
+ toco_flags->add_input_types(type);
+ }
+ }
+}
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.h b/tensorflow/contrib/lite/toco/toco_cmdline_flags.h
new file mode 100644
index 0000000000..155a6fea87
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.h
@@ -0,0 +1,35 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_
+
+#include <string>
+#include <vector>
+#include "tensorflow/contrib/lite/toco/args.h"
+#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
+
+namespace toco {
+// Parse and remove arguments handled from toco. Returns true if parsing
+// is successful. msg has the usage string if there was an error or
+// "--help" was specified
+bool ParseTocoFlagsFromCommandLineFlags(int* argc, char* argv[], string* msg,
+ ParsedTocoFlags* parsed_toco_flags_ptr);
+// Populate the TocoFlags proto with parsed_toco_flags data.
+void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
+ TocoFlags* toco_flags);
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_
diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto
new file mode 100644
index 0000000000..fd7c29fdc7
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_flags.proto
@@ -0,0 +1,126 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+syntax = "proto2";
+package toco;
+
+// Supported I/O file formats. Some formats may be input-only or output-only.
+enum FileFormat {
+ FILE_FORMAT_UNKNOWN = 0;
+
+ // GraphDef, third_party/tensorflow/core/framework/graph.proto
+ TENSORFLOW_GRAPHDEF = 1;
+
+ // Tensorflow's mobile inference model.
+ // third_party/tensorflow/contrib/tflite/schema.fbs
+ TFLITE = 2;
+
+ // GraphViz
+ // Export-only.
+ GRAPHVIZ_DOT = 3;
+}
+
+// IODataType describes the numeric data types to be used by the output format.
+// See input_type and inference_type below.
+enum IODataType {
+ IO_DATA_TYPE_UNKNOWN = 0;
+
+ // Float32, not quantized
+ FLOAT = 1;
+
+ // Uint8, quantized
+ QUANTIZED_UINT8 = 2;
+
+ // Int32, not quantized
+ INT32 = 3;
+
+ // Int64, not quantized
+ INT64 = 4;
+
+ // String, not quantized
+ STRING = 5;
+}
+
+// TocoFlags encodes extra parameters that drive tooling operations, that
+// are not normally encoded in model files and in general may not be thought
+// of as properties of models, instead describing how models are to be
+// processed in the context of the present tooling job.
+// Next Id: 11
+message TocoFlags {
+ // Input file format
+ optional FileFormat input_format = 1;
+
+ // Output file format
+ optional FileFormat output_format = 2;
+
+ // Numeric data types of the input arrays in the output format.
+ // This controls what input types the output file will be expecting.
+ // This is not a description of the input types of the input file.
+ // For example, the input file may have a float input placeholder,
+ // but we may want to generate a quantized TFLite file from it,
+ // or a float TFLite file taking a quantized input.
+ //
+ // The length of this list should match the length of the input_arrays
+ // list in ModelFlags.
+ repeated IODataType input_types = 9;
+
+ // Numeric data type of the internal activations array and output array.
+ //
+ // As a matter of implementation detail, most model
+ // parameter arrays (weights, etc) will tend to also use this data type.
+ // Not all will, though: for instance, bias vectors will typically
+ // get quantized as int32 when weights and activations get quantized as uint8.
+ optional IODataType inference_type = 4;
+
+ // default_ranges_min and default_ranges_max are helpers to experiment
+ // with quantization of models. Normally, quantization requires the input
+ // model to have (min, max) range information for every activations array.
+ // This is needed in order to know how to quantize arrays and still achieve
+ // satisfactory accuracy. However, in some circumstances one would just like
+ // to estimate the performance of quantized inference, without caring about
+ // accuracy. That is what default_ranges_min and default_ranges_max are for:
+ // when specified, they will be used as default (min, max) range boundaries
+ // for all activation arrays that lack (min, max) range information, thus
+ // allowing for quantization to proceed.
+ //
+ // It should be clear from the above explanation that these parameters are
+ // for experimentation purposes only and should not be used in production:
+ // they make it easy to quantize models, but the resulting quantized model
+ // will be inaccurate.
+ optional float default_ranges_min = 5;
+ optional float default_ranges_max = 6;
+
+ // Ignore and discard FakeQuant nodes. For instance, that can be used to
+ // generate plain float code without fake-quantization from a quantized
+ // graph.
+ optional bool drop_fake_quant = 7;
+
+ // Normally, FakeQuant nodes must be strict boundaries for graph
+ // transformations, in order to ensure that quantized inference has the
+ // exact same arithmetic behavior as quantized training --- which is the
+ // whole point of quantized training and of FakeQuant nodes in the first
+ // place. However, that entails subtle requirements on where exactly
+ // FakeQuant nodes must be placed in the graph. Some quantized graphs
+ // have FakeQuant nodes at unexpected locations, that prevent graph
+ // transformations that are necessary in order to generate inference
+ // code for these graphs. Such graphs should be fixed, but as a
+ // temporary work-around, setting this reorder_across_fake_quant flag
+ // allows toco to perform necessary graph transformaitons on them,
+ // at the cost of no longer faithfully matching inference and training
+ // arithmetic.
+ optional bool reorder_across_fake_quant = 8;
+
+ // If true, allow TOCO to create TF Lite Custom operators for all the
+ // unsupported Tensorflow ops.
+ optional bool allow_custom_ops = 10;
+}
diff --git a/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.cc b/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.cc
new file mode 100644
index 0000000000..4e98e7081d
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.cc
@@ -0,0 +1,22 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h"
+
+namespace toco {
+GraphVizDumpOptions* GraphVizDumpOptions::singleton() {
+ static auto* ptr = new GraphVizDumpOptions;
+ return ptr;
+}
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h b/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h
new file mode 100644
index 0000000000..ae0541f62b
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h
@@ -0,0 +1,34 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_
+
+#include <string>
+
+namespace toco {
+
+// Global data for determining whether to output graph viz format from toco.
+struct GraphVizDumpOptions {
+ std::string graphviz_first_array;
+ std::string graphviz_last_array;
+ std::string dump_graphviz;
+ bool dump_graphviz_video = false;
+
+ static GraphVizDumpOptions* singleton();
+};
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_
diff --git a/tensorflow/contrib/lite/toco/toco_port.cc b/tensorflow/contrib/lite/toco/toco_port.cc
new file mode 100644
index 0000000000..a1c8696cd0
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_port.cc
@@ -0,0 +1,227 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstring>
+
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+namespace port {
+void CopyToBuffer(const string& src, char* dest) {
+ memcpy(dest, src.data(), src.size());
+}
+
+#ifdef PLATFORM_GOOGLE
+void CopyToBuffer(const Cord& src, char* dest) { src.CopyToArray(dest); }
+#endif
+} // namespace port
+} // namespace toco
+
+#if defined(PLATFORM_GOOGLE) && !defined(__APPLE__) && !defined(__ANDROID__)
+
+// Wrap Google file operations.
+
+#include "base/init_google.h"
+#include "file/base/file.h"
+#include "file/base/filesystem.h"
+#include "file/base/helpers.h"
+#include "file/base/options.h"
+#include "file/base/path.h"
+
+namespace toco {
+namespace port {
+
+void InitGoogle(const char* usage, int* argc, char*** argv, bool remove_flags) {
+ ::InitGoogle(usage, argc, argv, remove_flags);
+}
+
+void CheckInitGoogleIsDone(const char* message) {
+ ::CheckInitGoogleIsDone(message);
+}
+
+namespace file {
+
+// Conversion to our wrapper Status.
+Status ToStatus(const ::util::Status& uts) {
+ return Status(uts.ok(), uts.error_message());
+}
+
+// Conversion to our wrapper Options.
+toco::port::file::Options ToOptions(const ::file::Options& options) {
+ CHECK_EQ(&options, &::file::Defaults());
+ return Options();
+}
+
+Status Writable(const string& filename) {
+ File* f = nullptr;
+ const auto status = ::file::Open(filename, "w", &f, ::file::Defaults());
+ if (f) {
+ QCHECK_OK(f->Close(::file::Defaults()));
+ }
+ return ToStatus(status);
+}
+
+Status Readable(const string& filename, const file::Options& options) {
+ return ToStatus(::file::Readable(filename, ::file::Defaults()));
+}
+
+Status Exists(const string& filename, const file::Options& options) {
+ auto status = ::file::Exists(filename, ::file::Defaults());
+ return ToStatus(status);
+}
+
+Status GetContents(const string& filename, string* contents,
+ const file::Options& options) {
+ return ToStatus(::file::GetContents(filename, contents, ::file::Defaults()));
+}
+
+Status SetContents(const string& filename, const string& contents,
+ const file::Options& options) {
+ return ToStatus(::file::SetContents(filename, contents, ::file::Defaults()));
+}
+
+string JoinPath(const string& a, const string& b) {
+ return ::file::JoinPath(a, b);
+}
+
+} // namespace file
+} // namespace port
+} // namespace toco
+
+#else // (__APPLE__ || __ANDROID__)
+
+#include <fcntl.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+#include <cstdio>
+
+#if defined(PLATFORM_GOOGLE)
+#include "base/commandlineflags.h"
+#endif
+
+namespace toco {
+namespace port {
+
+static bool port_initialized = false;
+
+void InitGoogle(const char* usage, int* argc, char*** argv, bool remove_flags) {
+ if (!port_initialized) {
+#if defined(PLATFORM_GOOGLE)
+ ParseCommandLineFlags(argc, argv, remove_flags);
+#endif
+ port_initialized = true;
+ }
+}
+
+void CheckInitGoogleIsDone(const char* message) {
+ CHECK(port_initialized) << message;
+}
+
+namespace file {
+
+Status Writable(const string& filename) {
+ FILE* f = fopen(filename.c_str(), "w");
+ if (f) {
+ fclose(f);
+ return Status(true, "");
+ }
+ return Status(false, "not writable");
+}
+
+Status Readable(const string& filename, const file::Options& options) {
+ FILE* f = fopen(filename.c_str(), "r");
+ if (f) {
+ fclose(f);
+ return Status(true, "");
+ }
+ return Status(false, "not readable");
+}
+
+Status Exists(const string& filename, const file::Options& options) {
+ struct stat statbuf;
+ int ret = stat(filename.c_str(), &statbuf);
+ return Status(ret != -1, "");
+}
+
+Status GetContents(const string& path, string* output,
+ const file::Options& options) {
+ output->clear();
+
+ int fd = open(path.c_str(), O_RDONLY);
+ if (fd == -1) {
+ return Status(false, "can't open() for read");
+ }
+
+ // Direct read, for speed.
+ const int kBufSize = 1 << 16;
+ char buffer[kBufSize];
+ while (true) {
+ int size = read(fd, buffer, kBufSize);
+ if (size == 0) {
+ // Done.
+ close(fd);
+ return Status(true, "");
+ } else if (size == -1) {
+ // Error.
+ close(fd);
+ return Status(false, "error during read()");
+ } else {
+ output->append(buffer, size);
+ }
+ }
+
+ CHECK(0);
+ return Status(false, "internal error");
+}
+
+Status SetContents(const string& filename, const string& contents,
+ const file::Options& options) {
+ int fd = open(filename.c_str(), O_WRONLY | O_CREAT, 0664);
+ if (fd == -1) {
+ return Status(false, "can't open() for write");
+ }
+
+ size_t i = 0;
+ while (i < contents.size()) {
+ size_t to_write = contents.size() - i;
+ ssize_t written = write(fd, &contents[i], to_write);
+ if (written == -1) {
+ close(fd);
+ return Status(false, "write() error");
+ }
+ i += written;
+ }
+ close(fd);
+
+ return Status(true, "");
+}
+
+string JoinPath(const string& base, const string& filename) {
+ if (base.empty()) return filename;
+ string base_fixed = base;
+ if (!base_fixed.empty() && base_fixed.back() == '/') base_fixed.pop_back();
+ string filename_fixed = filename;
+ if (!filename_fixed.empty() && filename_fixed.front() == '/')
+ filename_fixed.erase(0, 1);
+ return base_fixed + "/" + filename_fixed;
+}
+
+} // namespace file
+} // namespace port
+} // namespace toco
+
+#endif // (__APPLE || __ANDROID__)
diff --git a/tensorflow/contrib/lite/toco/toco_port.h b/tensorflow/contrib/lite/toco/toco_port.h
new file mode 100644
index 0000000000..b5cb7a11e7
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_port.h
@@ -0,0 +1,80 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_
+
+// Portability layer for toco tool. Mainly, abstract filesystem access so we
+// can build and use on google internal environments and on OSX.
+
+#include <string>
+#include "tensorflow/contrib/lite/toco/format_port.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/platform.h"
+#if defined(PLATFORM_GOOGLE)
+#include "absl/strings/cord.h"
+#endif // PLATFORM_GOOGLE
+
+#ifdef PLATFORM_GOOGLE
+#define TFLITE_PROTO_NS proto2
+#else
+#define TFLITE_PROTO_NS google::protobuf
+#endif
+
+namespace toco {
+namespace port {
+
+class Status {
+ public:
+ Status() {}
+
+ Status(bool ok, const string& message) : ok_(ok), message_(message) {}
+
+ bool ok() const { return ok_; }
+
+ const string error_message() const { return message_; }
+
+ private:
+ bool ok_ = false;
+ string message_;
+};
+
+void InitGoogle(const char* usage, int* argc, char*** argv, bool remove_flags);
+void CheckInitGoogleIsDone(const char* message);
+
+namespace file {
+class Options {};
+inline Options Defaults() {
+ Options o;
+ return o;
+}
+Status GetContents(const string& filename, string* contents,
+ const Options& options);
+Status SetContents(const string& filename, const string& contents,
+ const Options& options);
+string JoinPath(const string& base, const string& filename);
+Status Writable(const string& filename);
+Status Readable(const string& filename, const Options& options);
+Status Exists(const string& filename, const Options& options);
+} // namespace file
+
+// Copy `src` string to `dest`. User must ensure `dest` has enough space.
+#if defined(PLATFORM_GOOGLE)
+void CopyToBuffer(const ::Cord& src, char* dest);
+#endif // PLATFORM_GOOGLE
+void CopyToBuffer(const string& src, char* dest);
+} // namespace port
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_
diff --git a/tensorflow/contrib/lite/toco/toco_port_test.cc b/tensorflow/contrib/lite/toco/toco_port_test.cc
new file mode 100644
index 0000000000..650a617aeb
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_port_test.cc
@@ -0,0 +1,58 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace toco {
+namespace port {
+namespace {
+
+#ifdef PLATFORM_GOOGLE
+#define TFLITE_PREFIX "third_party/tensorflow/contrib/lite/"
+#else
+#define TFLITE_PREFIX "tensorflow/contrib/lite/"
+#endif
+
+TEST(TocoPortTest, Exists) {
+ EXPECT_TRUE(
+ file::Exists(TFLITE_PREFIX "toco/toco_port_test.cc", file::Defaults())
+ .ok());
+
+ EXPECT_FALSE(
+ file::Exists("non-existent_file_asldjflasdjf", file::Defaults()).ok());
+}
+
+TEST(TocoPortTest, Readable) {
+ EXPECT_TRUE(
+ file::Readable(TFLITE_PREFIX "toco/toco_port_test.cc", file::Defaults())
+ .ok());
+
+ EXPECT_FALSE(
+ file::Readable("non-existent_file_asldjflasdjf", file::Defaults()).ok());
+}
+
+TEST(TocoPortTest, JoinPath) {
+ EXPECT_EQ("part1/part2", file::JoinPath("part1", "part2"));
+ EXPECT_EQ("part1/part2", file::JoinPath("part1/", "part2"));
+ EXPECT_EQ("part1/part2", file::JoinPath("part1", "/part2"));
+ EXPECT_EQ("part1/part2", file::JoinPath("part1/", "/part2"));
+}
+
+} // namespace
+} // namespace port
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
new file mode 100644
index 0000000000..232538a841
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -0,0 +1,277 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/toco_tooling.h"
+
+#include <cstdlib>
+#include <memory>
+#include <set>
+
+#include "absl/strings/str_join.h"
+#include "tensorflow/contrib/lite/toco/allocate_transient_arrays.h"
+#include "tensorflow/contrib/lite/toco/dump_graphviz.h"
+#include "tensorflow/contrib/lite/toco/export_tensorflow.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/import_tensorflow.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/tflite/export.h"
+#include "tensorflow/contrib/lite/toco/tflite/import.h"
+#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+namespace {
+// CHECK-fails if the model contains a kTensorFlowUnsupported operation.
+void CheckUnsupportedOperations(const Model& model) {
+ std::set<string> unsupported_ops;
+ for (auto& op : model.operators) {
+ if (op->type == OperatorType::kTensorFlowUnsupported) {
+ unsupported_ops.insert(
+ static_cast<const TensorFlowUnsupportedOperator*>(op.get())
+ ->tensorflow_op);
+ }
+ }
+ QCHECK(unsupported_ops.empty())
+ << "These unsupported ops were not removed by graph transformations: "
+ << absl::StrJoin(unsupported_ops, ", ");
+}
+
+void MakeGeneralGraphTransformationsSet(
+ GraphTransformationsSet* transformations) {
+ CHECK(transformations->empty());
+ transformations->Add(new ResolveReshapeAttributes);
+ transformations->Add(new PropagateArrayDataTypes);
+ transformations->Add(new PropagateFixedSizes);
+ transformations->Add(new RemoveTensorFlowAssert);
+ transformations->Add(new RemoveTensorFlowIdentity);
+ transformations->Add(new RemoveTrivialConcatenation);
+ transformations->Add(new RemoveTrivialConcatenationInput);
+ transformations->Add(new RemoveUnusedOp);
+ transformations->Add(new EnsureBiasVectors);
+ transformations->Add(new ResolveReorderAxes);
+ transformations->Add(new ResolveTensorFlowMatMul);
+ transformations->Add(new FuseBinaryIntoPrecedingAffine);
+ transformations->Add(new FuseBinaryIntoFollowingAffine);
+ transformations->Add(new ResolveBatchNormalization);
+ transformations->Add(new ResolveConstantBinaryOperator);
+ transformations->Add(new ResolveConstantUnaryOperator);
+ transformations->Add(new ResolveTensorFlowMerge);
+ transformations->Add(new ResolveTensorFlowSqueeze);
+ transformations->Add(new ResolveTensorFlowSwitch);
+ transformations->Add(new ResolveTensorFlowTile);
+ transformations->Add(new ResolveTensorFlowConcat);
+ transformations->Add(new IdentifyL2Normalization);
+ transformations->Add(new IdentifyL2Pool);
+ transformations->Add(new IdentifyRelu1);
+ transformations->Add(new RemoveTrivialBinaryOperator);
+ transformations->Add(new ReadFakeQuantMinMax);
+ transformations->Add(new ResolvePadAttributes);
+ transformations->Add(new ResolveStridedSliceAttributes);
+ transformations->Add(new ResolveSliceAttributes);
+ transformations->Add(new ResolveMeanAttributes);
+ transformations->Add(new ResolveConstantTensorFlowShape);
+ transformations->Add(new MakeInitialDequantizeOperator);
+}
+
+void SetArrayFinalDataTypes(const TocoFlags& toco_flags, Model* model) {
+ const bool output_is_tflite = toco_flags.output_format() == TFLITE;
+
+ if (output_is_tflite) {
+ if (!toco_flags.input_types().empty()) {
+ for (int i = 0; i < model->flags.input_arrays_size(); i++) {
+ int input_types_index = toco_flags.input_types_size() == 1 ? 0 : i;
+ const auto input_type = toco_flags.input_types(input_types_index);
+ ArrayDataType final_data_type = ArrayDataType::kNone;
+ switch (input_type) {
+ case FLOAT:
+ final_data_type = ArrayDataType::kFloat;
+ break;
+ case QUANTIZED_UINT8:
+ final_data_type = ArrayDataType::kUint8;
+ break;
+ case INT32:
+ final_data_type = ArrayDataType::kInt32;
+ break;
+ case INT64:
+ final_data_type = ArrayDataType::kInt64;
+ break;
+ default:
+ LOG(FATAL) << "Unknown data type";
+ }
+ model->arrays[model->flags.input_arrays(i).name()]->final_data_type =
+ final_data_type;
+ }
+ }
+ } else {
+ for (int i = 0; i < model->flags.input_arrays_size(); i++) {
+ model->arrays[model->flags.input_arrays(i).name()]->final_data_type =
+ ArrayDataType::kFloat;
+ }
+ }
+}
+
+} // namespace
+
+std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
+ const ModelFlags& model_flags,
+ const string& input_file_contents) {
+ std::unique_ptr<Model> model;
+ switch (toco_flags.input_format()) {
+ case TENSORFLOW_GRAPHDEF:
+ model = ImportTensorFlowGraphDef(model_flags, input_file_contents);
+ break;
+ case TFLITE:
+ model = toco::tflite::Import(model_flags, input_file_contents);
+ ResolveModelFlags(model_flags, model.get());
+ CheckInvariants(*model);
+ break;
+ default:
+ LOG(FATAL) << "Unhandled input_format";
+ }
+
+ LogDump(kLogLevelModelChanged, "AT IMPORT", *model);
+
+ return model;
+}
+
+void Transform(const TocoFlags& toco_flags, Model* model) {
+ const FileFormat output_format = toco_flags.output_format();
+ const IODataType inference_type = toco_flags.inference_type();
+
+ const bool output_is_tflite = output_format == TFLITE;
+
+ const bool output_is_tflite_quantized =
+ output_is_tflite && inference_type == QUANTIZED_UINT8;
+
+ if (output_is_tflite) {
+ QCHECK(toco_flags.input_types_size() == 1 ||
+ toco_flags.input_types_size() == model->flags.input_arrays_size())
+ << "Mismatched numbers of input_arrays and input_types";
+ }
+
+ if (output_is_tflite_quantized) {
+ for (const auto& input_type : toco_flags.input_types()) {
+ QCHECK_NE(input_type, FLOAT)
+ << "Quantized inference is not allowed with float inputs.";
+ }
+ }
+
+ SetArrayFinalDataTypes(toco_flags, model);
+
+ GraphTransformationsSet transformations;
+ MakeGeneralGraphTransformationsSet(&transformations);
+ auto* remove_trivial_reshape = new RemoveTrivialReshape;
+ transformations.Add(remove_trivial_reshape);
+ if (output_format == TFLITE) {
+ transformations.Add(new FuseActivationFunctions);
+ } else {
+ transformations.Add(new UnfuseActivationFunctions);
+ }
+ if (output_format != TENSORFLOW_GRAPHDEF) {
+ transformations.Add(new ResolveConstantFakeQuant);
+ }
+ if (toco_flags.drop_fake_quant()) {
+ transformations.Add(new DropFakeQuant);
+ } else {
+ // See the doc for --reorder_across_fake_quant: that flag is needed to
+ // support some existing models, e.g. WordLens, that have FakeQuant
+ // nodes in the wrong places.
+ // We currently unconditionally enable that behavior when the output
+ // format is DarwiNN because the DarwiNN test code does not make it
+ // easy to pass a new toco flag. Once that is resolved on the DarwiNN
+ // tests side, the special-casing of DarwiNN here can go away.
+ // TODO(benoitjacob): so drop it when we can.
+ if ((output_is_tflite_quantized &&
+ toco_flags.reorder_across_fake_quant())) {
+ transformations.Add(new DropFakeQuant);
+ }
+ }
+ transformations.Add(new ConvertPureConvToDepthwise);
+ // TFLite export does not yet support fused LSTM cell.
+ if (output_format == TENSORFLOW_GRAPHDEF) {
+ transformations.Add(new IdentifyLstmCell);
+ }
+ transformations.Add(new ResolveConstantConcatenation);
+ RunGraphTransformations(model, "general graph transformations",
+ transformations);
+ if (output_is_tflite_quantized) {
+ RunGraphTransformations(model, "pre-quantization graph transformations",
+ {new HardcodeMinMax, new DropFakeQuant});
+ }
+
+ if (output_is_tflite_quantized) {
+ if (toco_flags.has_default_ranges_min() &&
+ toco_flags.has_default_ranges_max()) {
+ UseDefaultMinMaxRangeValues(model, toco_flags.default_ranges_min(),
+ toco_flags.default_ranges_max());
+ }
+ CheckIsReadyForQuantization(*model);
+ RunGraphTransformations(
+ model, "quantization graph transformations",
+ {new Quantize, new RemoveTrivialQuantizedActivationFunc,
+ new RemoveFinalDequantizeOp});
+ } else {
+ GraphTransformationsSet dequantization_transformations{new Dequantize};
+ // Dequantize creates FakeQuant nodes. We may want to discard
+ // those immediately.
+ if (toco_flags.drop_fake_quant()) {
+ dequantization_transformations.Add(new DropFakeQuant);
+ }
+
+ RunGraphTransformations(model, "dequantization graph transformations",
+ dequantization_transformations);
+ }
+
+ LogDump(kLogLevelModelChanged, "AFTER TRANSFORMATIONS", *model);
+
+ if (output_format != GRAPHVIZ_DOT && output_format != TFLITE) {
+ // By now there shouldn't be any unsupported ops when exporting to
+ // TensorFlow GraphDef.
+ CheckUnsupportedOperations(*model);
+ }
+
+ if (output_is_tflite) {
+ AllocateTransientArrays(model, kDefaultTransientDataAlignment);
+ LogDump(kLogLevelModelChanged, "AFTER ALLOCATION", *model);
+ }
+
+ CheckModelCounts(*model);
+ CheckFinalDataTypesSatisfied(*model);
+
+ int64 ops_count;
+ if (EstimateArithmeticOpsCount(*model, &ops_count)) {
+ LOG(INFO) << "Estimated count of arithmetic ops: " << 1e-9 * ops_count
+ << " billion (note that a multiply-add is counted as 2 ops).";
+ }
+}
+
+void Export(const TocoFlags& toco_flags, const Model& model,
+ bool allow_custom_ops, string* output_file_contents) {
+ switch (toco_flags.output_format()) {
+ case TENSORFLOW_GRAPHDEF:
+ ExportTensorFlowGraphDef(model, output_file_contents);
+ break;
+ case TFLITE:
+ toco::tflite::Export(model, allow_custom_ops, output_file_contents);
+ break;
+ case GRAPHVIZ_DOT:
+ DumpGraphviz(model, output_file_contents);
+ break;
+ default:
+ LOG(FATAL) << "Unhandled output_format";
+ }
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.h b/tensorflow/contrib/lite/toco/toco_tooling.h
new file mode 100644
index 0000000000..9c5a93a211
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_tooling.h
@@ -0,0 +1,50 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
+
+namespace toco {
+
+// Imports the input file into a Model object.
+std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
+ const ModelFlags& model_flags,
+ const string& input_file_contents);
+
+// Transforms a Model. The resulting Model is ready to be passed
+// to Export with the exact same toco_flags.
+void Transform(const TocoFlags& toco_flags, Model* model);
+
+// Exports the Model, which must be of the 'lowered' form returned by
+// Transform, to a file of the format given by
+// toco_flags.output_format().
+void Export(const TocoFlags& toco_flags, const Model& model,
+ bool allow_custom_ops, string* output_file_contents);
+
+// This if for backward-compatibility with internal tools.
+inline void Export(const TocoFlags& toco_flags, const Model& model,
+ string* output_file_contents) {
+ Export(toco_flags, model, true, output_file_contents);
+}
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_
diff --git a/tensorflow/contrib/lite/toco/toco_types.h b/tensorflow/contrib/lite/toco/toco_types.h
new file mode 100644
index 0000000000..ad42497ada
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_types.h
@@ -0,0 +1,45 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_
+
+#include <string>
+#include "tensorflow/core/platform/platform.h"
+
+#if defined(PLATFORM_GOOGLE) || defined(GOOGLE_INTEGRAL_TYPES)
+#include "tensorflow/core/platform/google/integral_types.h"
+#else
+#include "tensorflow/core/platform/default/integral_types.h"
+#endif
+
+namespace toco {
+#ifdef PLATFORM_GOOGLE
+using ::string;
+#else
+using std::string;
+#endif
+
+using tensorflow::int16;
+using tensorflow::int32;
+using tensorflow::int64;
+using tensorflow::int8;
+using tensorflow::uint16;
+using tensorflow::uint32;
+using tensorflow::uint64;
+using tensorflow::uint8;
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
new file mode 100644
index 0000000000..bcbfed62d3
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -0,0 +1,1552 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+#include <functional>
+#include <iterator>
+#include <set>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_replace.h"
+#include "tensorflow/contrib/lite/toco/dump_graphviz.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/core/platform/logging.h"
+
+
+namespace toco {
+
+string LogName(const Operator& op) {
+ const string& opname = HelpfulOperatorTypeName(op);
+ if (op.outputs.empty()) {
+ return toco::port::StringF("{%s operator}", opname);
+ } else {
+ return toco::port::StringF("{%s operator with output %s}", opname,
+ op.outputs[0]);
+ }
+}
+
+bool IsInputArray(const Model& model, const string& name) {
+ for (const auto& input_array : model.flags.input_arrays()) {
+ if (input_array.name() == name) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool IsArrayConsumed(const Model& model, const string& name) {
+ if (GetOpWithInput(model, name)) {
+ return true;
+ }
+ for (const string& model_output : model.flags.output_arrays()) {
+ if (model_output == name) {
+ return true;
+ }
+ }
+ for (const auto& rnn_state : model.flags.rnn_states()) {
+ if (rnn_state.back_edge_source_array() == name) {
+ return true;
+ }
+ }
+ return false;
+}
+
+int CountTrueOutputs(const Model& model, const Operator& op) {
+ int count = 0;
+ for (const string& output : op.outputs) {
+ if (IsArrayConsumed(model, output)) {
+ ++count;
+ }
+ }
+ return count;
+}
+
+int CountOpsWithInput(const Model& model, const string& array_name) {
+ int count = 0;
+ for (const auto& op : model.operators) {
+ for (auto& input : op->inputs) {
+ if (input == array_name) {
+ count++;
+ }
+ }
+ }
+ return count;
+}
+
+bool DeleteArrayIfUnused(const string& array_name, Model* model) {
+ if (CountOpsWithInput(*model, array_name) == 0) {
+ model->arrays.erase(array_name);
+ return true;
+ }
+ return false;
+}
+
+std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithOutput(
+ const Model& model, const string& array_name) {
+ for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
+ for (auto& output : it->get()->outputs) {
+ if (output == array_name) {
+ return it;
+ }
+ }
+ }
+ return model.operators.end();
+}
+
+std::vector<std::unique_ptr<Operator>>::iterator FindOpWithOutput(
+ Model& model, const string& array_name) {
+ for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
+ for (auto& output : it->get()->outputs) {
+ if (output == array_name) {
+ return it;
+ }
+ }
+ }
+ return model.operators.end();
+}
+
+Operator* GetOpWithOutput(const Model& model, const string& array_name) {
+ auto it = FindOpWithOutput(model, array_name);
+ return it == model.operators.end() ? nullptr : it->get();
+}
+
+// GetFirstOpWithInput assumes that this finds the first op.
+std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithInput(
+ const Model& model, const string& array_name) {
+ for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
+ for (auto& input : it->get()->inputs) {
+ if (input == array_name) {
+ return it;
+ }
+ }
+ }
+ return model.operators.end();
+}
+
+std::vector<std::unique_ptr<Operator>>::const_iterator FindOp(
+ const Model& model, const Operator* op) {
+ for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
+ if (it->get() == op) {
+ return it;
+ }
+ }
+ return model.operators.end();
+}
+
+std::vector<std::unique_ptr<Operator>>::iterator FindOp(Model& model,
+ const Operator* op) {
+ for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
+ if (it->get() == op) {
+ return it;
+ }
+ }
+ return model.operators.end();
+}
+
+Operator* GetOpWithInput(const Model& model, const string& array_name) {
+ auto it = FindOpWithInput(model, array_name);
+ return it == model.operators.end() ? nullptr : it->get();
+}
+
+Operator* GetFirstOpWithInput(const Model& model, const string& array_name) {
+ auto it = FindOpWithInput(model, array_name);
+ return it == model.operators.end() ? nullptr : it->get();
+}
+
+string FormatArraysList(const Model& model, const std::vector<string>& list) {
+ if (list.empty()) {
+ return "[]";
+ }
+ string result = "";
+ if (list.size() > 1) {
+ result += "[ ";
+ }
+ for (std::size_t i = 0; i < list.size(); i++) {
+ if (i > 0) {
+ result += ", ";
+ }
+ result += list[i];
+ }
+ if (list.size() > 1) {
+ result += " ]";
+ }
+ return result;
+}
+
+const char* OperatorTypeName(OperatorType type) {
+ switch (type) {
+#define HANDLE_OPERATORTYPENAME_CASE(c) \
+ case OperatorType::k##c: \
+ return #c;
+ HANDLE_OPERATORTYPENAME_CASE(Add)
+ HANDLE_OPERATORTYPENAME_CASE(AveragePool)
+ HANDLE_OPERATORTYPENAME_CASE(BatchNormalization)
+ HANDLE_OPERATORTYPENAME_CASE(Conv)
+ HANDLE_OPERATORTYPENAME_CASE(Concatenation)
+ HANDLE_OPERATORTYPENAME_CASE(DepthwiseConv)
+ HANDLE_OPERATORTYPENAME_CASE(DepthToSpace)
+ HANDLE_OPERATORTYPENAME_CASE(SpaceToDepth)
+ HANDLE_OPERATORTYPENAME_CASE(FullyConnected)
+ HANDLE_OPERATORTYPENAME_CASE(Dequantize)
+ HANDLE_OPERATORTYPENAME_CASE(L2Normalization)
+ HANDLE_OPERATORTYPENAME_CASE(LocalResponseNormalization)
+ HANDLE_OPERATORTYPENAME_CASE(Logistic)
+ HANDLE_OPERATORTYPENAME_CASE(LstmCell)
+ HANDLE_OPERATORTYPENAME_CASE(MaxPool)
+ HANDLE_OPERATORTYPENAME_CASE(L2Pool)
+ HANDLE_OPERATORTYPENAME_CASE(FakeQuant)
+ HANDLE_OPERATORTYPENAME_CASE(Mul)
+ HANDLE_OPERATORTYPENAME_CASE(Relu)
+ HANDLE_OPERATORTYPENAME_CASE(Relu1)
+ HANDLE_OPERATORTYPENAME_CASE(Relu6)
+ HANDLE_OPERATORTYPENAME_CASE(ReorderAxes)
+ HANDLE_OPERATORTYPENAME_CASE(Softmax)
+ HANDLE_OPERATORTYPENAME_CASE(Div)
+ HANDLE_OPERATORTYPENAME_CASE(Tanh)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowAll)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowAssert)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreater)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreaterEqual)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowIdentity)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowLess)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowLessEqual)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowMatMul)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowMax)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowMaximum)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowMerge)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowMin)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowMinimum)
+ HANDLE_OPERATORTYPENAME_CASE(Pad)
+ HANDLE_OPERATORTYPENAME_CASE(StridedSlice)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowReshape)
+ HANDLE_OPERATORTYPENAME_CASE(Squeeze)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowRsqrt)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowShape)
+ HANDLE_OPERATORTYPENAME_CASE(Slice)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowSplit)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowSqrt)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowSquare)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowSwitch)
+ HANDLE_OPERATORTYPENAME_CASE(Sub)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowSum)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowTile)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcat)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcatV2)
+ HANDLE_OPERATORTYPENAME_CASE(Cast)
+ HANDLE_OPERATORTYPENAME_CASE(Floor)
+ HANDLE_OPERATORTYPENAME_CASE(Gather)
+ HANDLE_OPERATORTYPENAME_CASE(ResizeBilinear)
+ HANDLE_OPERATORTYPENAME_CASE(SpaceToBatchND)
+ HANDLE_OPERATORTYPENAME_CASE(BatchToSpaceND)
+ HANDLE_OPERATORTYPENAME_CASE(Mean)
+ HANDLE_OPERATORTYPENAME_CASE(Svdf)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowUnsupported)
+ default:
+ LOG(FATAL) << "Unhandled op type";
+#undef HANDLE_OPERATORTYPENAME_CASE
+ }
+}
+
+string HelpfulOperatorTypeName(const Operator& op) {
+ if (op.type == OperatorType::kTensorFlowUnsupported) {
+ return toco::port::StringF(
+ "(Unsupported TensorFlow op: %s)",
+ static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op);
+ }
+ return OperatorTypeName(op.type);
+}
+
+void LogSummary(int log_level, const Model& model) {
+ VLOG(log_level) << "Operators summary (" << model.operators.size()
+ << " operators): ";
+ std::unordered_multiset<OperatorType> ops_by_type;
+ for (const auto& op : model.operators) {
+ ops_by_type.insert(op->type);
+ }
+ auto it = ops_by_type.begin();
+ while (it != ops_by_type.end()) {
+ int count = ops_by_type.count(*it);
+ VLOG(log_level) << " " << OperatorTypeName(*it) << ": " << count;
+ std::advance(it, count);
+ }
+}
+
+void LogArray(int log_level, const Model& model, const string& name) {
+ const auto& array = model.GetArray(name);
+ VLOG(log_level) << "Array: " << name;
+ switch (array.data_type) {
+ case ArrayDataType::kNone:
+ break;
+ case ArrayDataType::kFloat:
+ VLOG(log_level) << " Data type: kFloat";
+ break;
+ case ArrayDataType::kInt32:
+ VLOG(log_level) << " Data type: kInt32";
+ break;
+ case ArrayDataType::kUint8:
+ VLOG(log_level) << " Data type: kUint8";
+ break;
+ default:
+ VLOG(log_level) << " Data type: other (numerical value: "
+ << static_cast<int>(array.data_type) << ")";
+ break;
+ }
+ if (array.buffer) {
+ VLOG(log_level) << " Constant Buffer";
+ }
+ if (array.alloc) {
+ VLOG(log_level) << " Transient Alloc";
+ }
+ if (array.has_shape()) {
+ const Shape& array_shape = array.shape();
+ if (array_shape.dimensions_count() == 0) {
+ VLOG(log_level) << " (Zero dimensions)";
+ } else {
+ string message = " Dims: ";
+ bool first = true;
+ for (const int dim : array_shape.dims()) {
+ if (!first) {
+ message += ", ";
+ }
+ first = false;
+ toco::port::AppendF(&message, "%d", dim);
+ }
+ VLOG(log_level) << message;
+ }
+ }
+ if (array.minmax) {
+ VLOG(log_level) << " MinMax: " << array.minmax->min << " .. "
+ << array.minmax->max;
+ }
+ if (array.quantization_params) {
+ VLOG(log_level) << " QuantizationParams: zero_point="
+ << array.quantization_params->zero_point
+ << ", scale=" << array.quantization_params->scale;
+ }
+}
+
+void DumpGraphvizVideoFrame(const Model& model) {
+ namespace port = toco::port;
+
+ const auto& dump_options = *GraphVizDumpOptions::singleton();
+ if (!dump_options.dump_graphviz_video) {
+ return;
+ }
+ CHECK(!dump_options.dump_graphviz.empty());
+ // TODO(benoitjacob): the static data here means that this function
+ // is stateful, not reentrant, and effectively leaks memory till exit
+ // (since dump_hashes can only grow in size). It also means that it
+ // really only is intended to be called for a single model during the
+ // process' lifetime. So it's not great design at all. The overriding
+ // design aspect here is to make the video-dumping code as unintrusive
+ // and self-contained as possible. Eventually, we'll want to have that
+ // cleaned-up, but that will require some form of general statefulness
+ // in toco (some kind of 'tooling state' data structure) that does
+ // not exist at present, and would be premature to design here just for
+ // this new video-dumping feature.
+ static int dump_id = 0;
+ static std::unordered_set<std::size_t> dump_hashes;
+ string graphviz_dump;
+ DumpGraphviz(model, &graphviz_dump);
+ std::size_t hash = std::hash<string>{}(graphviz_dump);
+ if (!dump_hashes.count(hash)) {
+ dump_hashes.insert(hash);
+ CHECK(port::file::SetContents(
+ port::file::JoinPath(
+ dump_options.dump_graphviz,
+ toco::port::StringF("toco_video_%05d.dot", dump_id)),
+ graphviz_dump, port::file::Defaults())
+ .ok());
+ dump_id++;
+ }
+}
+
+void LogDump(int log_level, const string& message, const Model& model) {
+ namespace port = toco::port;
+ const auto& dump_options = *GraphVizDumpOptions::singleton();
+
+ DumpGraphvizVideoFrame(model);
+ if (!dump_options.dump_graphviz.empty()) {
+ string graphviz_dump;
+
+ DumpGraphviz(model, &graphviz_dump);
+ CHECK(port::file::SetContents(
+ port::file::JoinPath(
+ dump_options.dump_graphviz,
+ absl::StrCat("toco_",
+ absl::StrReplaceAll(message, {{" ", "_"}}),
+ ".dot")),
+ graphviz_dump, port::file::Defaults())
+ .ok());
+ }
+
+ if (!VLOG_IS_ON(log_level)) {
+ return;
+ }
+ VLOG(log_level) << "BEGIN DUMP OF TOCO MODEL (" << message << ")";
+ LogSummary(log_level, model);
+ std::unordered_set<string> already_printed_arrays;
+ for (const auto& op : model.operators) {
+ for (const auto& input : op->inputs) {
+ if (!already_printed_arrays.count(input)) {
+ already_printed_arrays.insert(input);
+ LogArray(log_level, model, input);
+ }
+ }
+ VLOG(log_level) << HelpfulOperatorTypeName(*op) << " : ";
+ VLOG(log_level) << " " << FormatArraysList(model, op->inputs) << " -> "
+ << FormatArraysList(model, op->outputs);
+ if (op->fused_activation_function != FusedActivationFunctionType::kNone) {
+ VLOG(log_level) << " (with fused activation function)";
+ }
+ for (const auto& output : op->outputs) {
+ if (!already_printed_arrays.count(output)) {
+ already_printed_arrays.insert(output);
+ LogArray(log_level, model, output);
+ }
+ }
+ }
+ VLOG(log_level) << "END DUMP OF TOCO MODEL (" << message << ")";
+}
+
+// Note remaining raw-array extension in ProcessTensorFlowReshapeOperator().
+void ExtendShape(Shape* shape, int new_shape_size) {
+ CHECK_GE(new_shape_size, shape->dimensions_count());
+ const int size_increase = new_shape_size - shape->dimensions_count();
+ auto* shape_dims = shape->mutable_dims();
+ shape_dims->insert(shape_dims->begin(), size_increase, 1);
+}
+
+// TODO(b/62904716) Remove along with remaining uses.
+void UnextendShape(Shape* shape, int new_shape_size) {
+ CHECK_LE(new_shape_size, shape->dimensions_count());
+ const int size_reduction = shape->dimensions_count() - new_shape_size;
+ for (int i = 0; i < size_reduction; i++) {
+ CHECK_EQ(shape->dims(i), 1);
+ }
+ std::vector<int>& shape_dims = *shape->mutable_dims();
+ shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction);
+}
+
+void CheckShapeDimensions(const Shape& shape) {
+ for (int i = 0; i < shape.dimensions_count(); ++i) {
+ CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index << " << i
+ << ". shape = " << ShapeToString(shape);
+ }
+}
+
+bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) {
+ CheckShapeDimensions(shape0);
+ CheckShapeDimensions(shape1);
+
+ const Shape* longer = &shape0;
+ const Shape* shorter = &shape1;
+ if (shape1.dimensions_count() > shape0.dimensions_count()) {
+ longer = &shape1;
+ shorter = &shape0;
+ }
+
+ // Walk dimensions back to front until we run out of dimensions in the shorter
+ // shape.
+ int longer_index = longer->dimensions_count() - 1;
+ int shorter_index = shorter->dimensions_count() - 1;
+ while (shorter_index >= 0) {
+ const int d_long = longer->dims(longer_index);
+ const int d_short = shorter->dims(shorter_index);
+ // Broadcasting fails if the dimensions are different *and* neither is 1.
+ if ((d_long != d_short) && (d_long != 1) && (d_short != 1)) {
+ return false;
+ }
+ longer_index--;
+ shorter_index--;
+ }
+ return true;
+}
+
+bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) {
+ CheckShapeDimensions(shape0);
+ CheckShapeDimensions(shape1);
+
+ const Shape* longer = &shape0;
+ const Shape* shorter = &shape1;
+ if (shape1.dimensions_count() > shape0.dimensions_count()) {
+ longer = &shape1;
+ shorter = &shape0;
+ }
+
+ // Walk dimensions back to front until we run out of dimensions in the shorter
+ // shape.
+ int longer_index = longer->dimensions_count() - 1;
+ int shorter_index = shorter->dimensions_count() - 1;
+ while (shorter_index >= 0) {
+ const int d_long = longer->dims(longer_index);
+ const int d_short = shorter->dims(shorter_index);
+ // Extending fails if the dimensions are different.
+ if (d_long != d_short) {
+ return false;
+ }
+ longer_index--;
+ shorter_index--;
+ }
+
+ // The remaining dimensions in the longer shape must be 1.
+ while (longer_index >= 0) {
+ const int d_long = longer->dims(longer_index);
+ if (d_long != 1) {
+ return false;
+ }
+ longer_index--;
+ }
+
+ return true;
+}
+
+int RequiredBufferSizeForShape(const Shape& shape) {
+ int max_offset = 1;
+ for (const auto& dim : shape.dims()) {
+ CHECK_GE(dim, 1);
+ max_offset *= dim;
+ }
+ return max_offset;
+}
+
+bool IsConstantParameterArray(const Model& model, const string& name) {
+ if (!model.arrays.count(name)) {
+ return false;
+ }
+
+ return !!model.arrays.at(name)->buffer;
+}
+
+void CheckNoMissingArray(const Model& model) {
+ for (const auto& op : model.operators) {
+ for (const auto& input : op->inputs) {
+ CHECK(model.arrays.count(input));
+ }
+ for (const auto& output : op->outputs) {
+ CHECK(model.arrays.count(output));
+ }
+ }
+ for (const auto& input_array : model.flags.input_arrays()) {
+ CHECK(model.arrays.count(input_array.name()))
+ << "Input array not found: " << input_array.name();
+ }
+ for (const string& output_array : model.flags.output_arrays()) {
+ CHECK(model.arrays.count(output_array))
+ << "Output array not found: " << output_array;
+ }
+ for (const auto& rnn_state : model.flags.rnn_states()) {
+ CHECK(model.arrays.count(rnn_state.state_array()));
+ CHECK(model.arrays.count(rnn_state.back_edge_source_array()));
+ }
+}
+
+void FixNoMissingArray(Model* model) {
+ for (const auto& op : model->operators) {
+ for (const auto& input : op->inputs) {
+ if (!model->arrays.count(input)) {
+ model->GetOrCreateArray(input);
+ }
+ }
+ for (const auto& output : op->outputs) {
+ if (!model->arrays.count(output)) {
+ model->GetOrCreateArray(output);
+ }
+ }
+ }
+ for (const string& output_array : model->flags.output_arrays()) {
+ if (!model->arrays.count(output_array)) {
+ model->GetOrCreateArray(output_array);
+ }
+ }
+}
+
+void CheckNoOrphanedArray(const Model& model) {
+ std::unordered_set<string> arrays_without_known_use;
+ for (const auto& array : model.arrays) {
+ arrays_without_known_use.insert(array.first);
+ }
+ for (const auto& op : model.operators) {
+ for (const auto& input : op->inputs) {
+ arrays_without_known_use.erase(input);
+ }
+ for (const auto& output : op->outputs) {
+ arrays_without_known_use.erase(output);
+ }
+ }
+ if (!arrays_without_known_use.empty()) {
+ for (const auto& array : arrays_without_known_use) {
+ LOG(INFO) << "Error: Orphaned array: " << array;
+ }
+ }
+ CHECK(arrays_without_known_use.empty());
+}
+
+void FixNoOrphanedArray(Model* model) {
+ std::unordered_set<string> arrays_without_known_use;
+ for (const auto& array : model->arrays) {
+ arrays_without_known_use.insert(array.first);
+ }
+ for (const auto& op : model->operators) {
+ for (const auto& input : op->inputs) {
+ arrays_without_known_use.erase(input);
+ }
+ for (const auto& output : op->outputs) {
+ arrays_without_known_use.erase(output);
+ }
+ }
+ for (const auto& array : arrays_without_known_use) {
+ model->arrays.erase(array);
+ }
+}
+
+void CheckArrayFieldsConsistent(const Model& model) {
+ for (const auto& array_entry : model.arrays) {
+ const auto& array = array_entry.second;
+ if (array->has_shape()) {
+ for (int d : array->shape().dims()) {
+ CHECK_GE(d, 1);
+ }
+ }
+ // It's OK to have a buffer or an alloc, but not both.
+ // (Since allocs are for transient arrays without a buffer).
+ CHECK(!array->buffer || !array->alloc);
+ // If there is a buffer, its type should be consistent with data_type.
+ if (array->buffer) {
+ CHECK(array->buffer->type == array->data_type);
+ }
+ }
+}
+
+void CheckOperatorOrdering(const Model& model) {
+ std::unordered_set<string> arrays_behind_us;
+ for (const auto& array_entry : model.arrays) {
+ if (!GetOpWithOutput(model, array_entry.first)) {
+ arrays_behind_us.insert(array_entry.first);
+ }
+ }
+ for (const auto& op : model.operators) {
+ for (const auto& input : op->inputs) {
+ if (!IsConstantParameterArray(model, input)) {
+ CHECK(arrays_behind_us.count(input));
+ }
+ }
+ for (const auto& output : op->outputs) {
+ CHECK(!arrays_behind_us.count(output));
+ arrays_behind_us.insert(output);
+ }
+ }
+ for (const string& output_array : model.flags.output_arrays()) {
+ CHECK(arrays_behind_us.count(output_array));
+ }
+}
+
+void FixOperatorOrdering(Model* model) {
+ std::unordered_set<string> arrays_behind_us;
+ for (const auto& array_entry : model->arrays) {
+ if (!GetOpWithOutput(*model, array_entry.first)) {
+ arrays_behind_us.insert(array_entry.first);
+ }
+ }
+ std::vector<std::unique_ptr<Operator>> old_operators;
+ std::swap(old_operators, model->operators);
+ std::set<std::size_t> remaining;
+ for (std::size_t i = 0; i < old_operators.size(); i++) {
+ remaining.insert(i);
+ }
+ std::unordered_map<string, string> reason_why_leftover;
+ while (true) {
+ bool inserted_something = false;
+ for (auto i : remaining) {
+ bool can_insert = true;
+ auto& op = old_operators[i];
+ CHECK(op.get());
+ for (const auto& input : op->inputs) {
+ if (!IsConstantParameterArray(*model, input) &&
+ !arrays_behind_us.count(input)) {
+ for (const string& output : op->outputs) {
+ reason_why_leftover[output] = input;
+ }
+ can_insert = false;
+ break;
+ }
+ }
+ if (can_insert) {
+ model->operators.emplace_back(nullptr);
+ for (const auto& output : op->outputs) {
+ arrays_behind_us.insert(output);
+ }
+ std::swap(op, model->operators.back());
+ remaining.erase(i);
+ inserted_something = true;
+ break;
+ }
+ }
+ if (!inserted_something) {
+ break;
+ }
+ }
+ if (!remaining.empty()) {
+ LOG(ERROR)
+ << "No viable ordering of operators was found. "
+ << "Here is a 'backtrace' of at least one part of the graph that is "
+ << "problematic. It starts with the first operator that has as "
+ << "problematic input array, and then walks back the graph to "
+ << "the operator that produced that input array, etc., until we find "
+ << "the root cause:";
+ LOG(ERROR) << "BEGIN TRACE OF OPERATOR WITH BAD INPUT";
+ LOG(ERROR) << "Here is the first-encountered operator with a bad input: ";
+ const Operator* bad_op = old_operators[*remaining.begin()].get();
+ std::unordered_set<string> bad_inputs_already_traced;
+ // The following while(true) loop should always end with a LOG(FATAL).
+ while (true) {
+ LOG(ERROR) << HelpfulOperatorTypeName(*bad_op) << " : "
+ << FormatArraysList(*model, bad_op->inputs) << " -> "
+ << FormatArraysList(*model, bad_op->outputs);
+ bool found_bad_output = false;
+ string bad_output;
+ for (const string& output : bad_op->outputs) {
+ if (reason_why_leftover.count(output)) {
+ found_bad_output = true;
+ bad_output = output;
+ break;
+ }
+ }
+ CHECK(found_bad_output);
+ const string& bad_input = reason_why_leftover[bad_output];
+ LOG(ERROR) << "The bad input here is: " << bad_input;
+ if (bad_inputs_already_traced.count(bad_input)) {
+ LOG(FATAL)
+ << "Cycle found! We already encountered that "
+ << "input array, " << bad_input << ", earlier in the "
+ << "above trace! We expect graphs to be acyclic, even "
+ << "RNNs. Let us know if some graph actually needs to have "
+ << "cycles, but first, please check if it really is "
+ << "an *inference* graph. *Training* graphs are out-of-scope "
+ << "for toco.";
+ }
+ bad_inputs_already_traced.insert(bad_input);
+ bad_op = nullptr;
+ for (auto i : remaining) {
+ const Operator* op = old_operators[i].get();
+ for (const string& output : op->outputs) {
+ if (bad_input == output) {
+ bad_op = op;
+ break;
+ }
+ }
+ if (bad_op) {
+ break;
+ }
+ }
+ if (!bad_op) {
+ LOG(ERROR) << "And that's the root cause: "
+ << "that array, " << bad_input << ", isn't produced by any "
+ << "operator, or provided in any other way.";
+ LOG(ERROR) << "END TRACE OF OPERATOR WITH BAD INPUT";
+ LOG(FATAL) << "(The above was a multi-line fatal error)";
+ }
+ LOG(ERROR) << "And that array is the output of the following operator:";
+ }
+ }
+ CHECK(remaining.empty())
+ << "Should never get here! In case of bad graph, "
+ << "the above code should have generated a FATAL error already!";
+}
+
+// Checks that the --input_arrays of the Model are actually used by at least
+// one of the --output_arrays i.e. that the graph contains a path from each one
+// of the inputs to at least one of the outputs. This catches cases where the
+// user passed the wrong --input_arrays or --output_arrays, which otherwise may
+// result in cryptic error messages.
+void CheckInputUsedByOutputs(const Model& model) {
+ std::set<string> used_arrays;
+ for (const string& output : model.flags.output_arrays()) {
+ used_arrays.insert(output);
+ }
+ for (int i = model.operators.size() - 1; i >= 0; i--) {
+ bool is_op_used = false;
+ for (const string& op_output : model.operators[i]->outputs) {
+ if (used_arrays.count(op_output)) {
+ is_op_used = true;
+ break;
+ }
+ }
+ if (!is_op_used) {
+ continue;
+ }
+ for (const string& op_input : model.operators[i]->inputs) {
+ used_arrays.insert(op_input);
+ }
+ }
+ for (const auto& input_array : model.flags.input_arrays()) {
+ QCHECK(used_arrays.count(input_array.name()))
+ << "The graph does not connect the input (" << input_array.name()
+ << ") specified by --input_arrays to any of the specified "
+ << "--output_arrays ("
+ << absl::StrJoin(model.flags.output_arrays(), ", ")
+ << "). Did you pass the wrong flags for this model, "
+ << "or is that model's graph actually incomplete?";
+ }
+}
+
+void CheckInvariants(const Model& model) {
+ CheckNoMissingArray(model);
+ CheckNoOrphanedArray(model);
+ CheckArrayFieldsConsistent(model);
+ CheckOperatorOrdering(model);
+ CheckInputUsedByOutputs(model);
+}
+
+void CheckCountInRange(const ::toco::ModelFlags::ModelCheck& model_check,
+ const int count, const string& count_description) {
+ if (model_check.count_min() >= 0) {
+ CHECK_GE(count, model_check.count_min())
+ << "Mismatch in " << count_description << ": count was " << count
+ << ", but the specified "
+ << (model_check.count_max() > model_check.count_min() ? "minimum"
+ : "value")
+ << " was " << model_check.count_min() << ".";
+ }
+ if (model_check.count_max() > model_check.count_min()) {
+ CHECK_LE(count, model_check.count_max())
+ << "Mismatch in " << count_description << ": count was " << count
+ << ", but the specified maximum was " << model_check.count_max() << ".";
+ }
+}
+
+void CheckModelCounts(const Model& model) {
+ std::unordered_multiset<OperatorType> ops_by_type;
+ std::unordered_map<string, OperatorType> op_type_by_name;
+ if (model.flags.model_checks_size() == 0) {
+ return;
+ }
+
+ for (const auto& op : model.operators) {
+ ops_by_type.insert(op->type);
+ op_type_by_name[OperatorTypeName(op->type)] = op->type;
+ }
+ for (const auto& model_check : model.flags.model_checks()) {
+ string count_type = model_check.count_type();
+ if (count_type == "None") {
+ continue;
+ } else if (count_type == "Arrays") {
+ CheckCountInRange(model_check, model.arrays.size(), "count of arrays");
+ } else if (count_type == "Total") {
+ CheckCountInRange(model_check, model.operators.size(),
+ "count of all operator instances");
+ } else {
+ // The check type is not itself checked against the set of valid
+ // operators, mainly because the enum set cannot be iterated in C++.
+ const int found_count =
+ op_type_by_name.count(count_type) > 0
+ ? ops_by_type.count(op_type_by_name[count_type])
+ : 0;
+ CheckCountInRange(model_check, found_count,
+ "count of instances of " + count_type + " operator");
+ }
+ }
+}
+
+void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
+ std::vector<int>* out_dims) {
+ CHECK(out_dims->empty());
+ if (num_dims == 1) {
+ CHECK_EQ(batch, 1);
+ *out_dims = {depth};
+ } else if (num_dims == 2) {
+ *out_dims = {batch, depth};
+ } else if (num_dims == 3) {
+ CHECK_EQ(batch, 1);
+ *out_dims = {height, width, depth};
+ } else if (num_dims == 4) {
+ *out_dims = {batch, height, width, depth};
+ } else {
+ LOG(FATAL) << "Should not get here: " << num_dims;
+ }
+}
+
+void CreateOrCheckRnnStateArray(const string& name, int size, Model* model) {
+ int batch = 1;
+ int num_dims = -1;
+ for (const auto& input_array : model->flags.input_arrays()) {
+ // Pick 'num_dims' and 'batch' from the first input_arrays, unless we find
+ // a better match by name.
+ if (input_array.name() == name || num_dims == -1) {
+ num_dims = input_array.shape_size();
+ if (num_dims != 0) {
+ batch = input_array.shape(0);
+ }
+ }
+ }
+ Array& array = model->GetOrCreateArray(name);
+ if (array.has_shape()) {
+ num_dims = array.shape().dimensions_count();
+ }
+ std::vector<int> dims;
+ MakeArrayDims(num_dims, batch, 1, 1, size, &dims);
+ CHECK(array.data_type == ArrayDataType::kFloat ||
+ array.data_type == ArrayDataType::kNone);
+ array.data_type = ArrayDataType::kFloat;
+ if (!array.has_shape()) {
+ Shape* shape = array.mutable_shape();
+ *shape->mutable_dims() = dims;
+ }
+}
+
+void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
+ // Merge info about input_arrays from model_flags into model->flags
+ for (const auto& specified_input_array : model_flags.input_arrays()) {
+ toco::InputArray* dst_input_array = nullptr;
+ for (int i = 0; i < model->flags.input_arrays_size(); i++) {
+ toco::InputArray* candidate_dst_input_array =
+ model->flags.mutable_input_arrays(i);
+ if (candidate_dst_input_array->name() == specified_input_array.name()) {
+ // specified_input_array from model_flags maps to dst_input_array
+ // in model->flags
+ dst_input_array = candidate_dst_input_array;
+ break;
+ }
+ }
+ if (!dst_input_array) {
+ // specified_input_array from model_flags is not found in model->flags.
+ // Match a name-less specified input array when there can be no ambiguity
+ // as there is only 1 input array.
+ if (model->flags.input_arrays_size() == 1 &&
+ model_flags.input_arrays_size() == 1 &&
+ !specified_input_array.has_name()) {
+ dst_input_array = model->flags.mutable_input_arrays(0);
+ }
+ }
+ if (!dst_input_array) {
+ // Still no match, so create a new input array to copy
+ // specified_input_array into.
+ dst_input_array = model->flags.add_input_arrays();
+ dst_input_array->set_name(specified_input_array.name());
+ }
+
+#define RESOLVE_MODEL_FLAG(field_name) \
+ if (specified_input_array.has_##field_name()) { \
+ if (dst_input_array->has_##field_name()) { \
+ QCHECK_EQ(dst_input_array->field_name(), \
+ specified_input_array.field_name()) \
+ << "For input array '" << dst_input_array->name() << "', " \
+ << "specified " #field_name " flag with value: " \
+ << specified_input_array.field_name() \
+ << " does not agree with already defined " #field_name \
+ " of this model, with value: " \
+ << specified_input_array.field_name(); \
+ } else { \
+ dst_input_array->set_##field_name(specified_input_array.field_name()); \
+ } \
+ }
+ RESOLVE_MODEL_FLAG(std_value);
+ RESOLVE_MODEL_FLAG(mean_value);
+#undef RESOLVE_MODEL_FLAG
+
+ if (!specified_input_array.shape().empty()) {
+ if (!dst_input_array->shape().empty()) {
+ QCHECK_EQ(specified_input_array.shape().size(),
+ dst_input_array->shape().size())
+ << "For input array '" << specified_input_array.name() << "', "
+ << "size of specified input shape flag with size: "
+ << specified_input_array.shape().size()
+ << " does not agree with already defined input shape"
+ " of this model, with size: "
+ << dst_input_array->shape().size();
+ // We treat the first dimension as a special case, since it is often
+ // a batch size and the input_shape flag is effectively overriding
+ // the model.
+ for (int i = 1; i < specified_input_array.shape().size(); i++) {
+ QCHECK_EQ(specified_input_array.shape().Get(i),
+ dst_input_array->shape().Get(i))
+ << "At dimension number " << i << " of input array "
+ << specified_input_array.name() << ", the specified shape's "
+ << "dimension flag with dimension: "
+ << specified_input_array.shape().Get(i)
+ << " does not agree with already defined shape"
+ << " of this model, with dimension: "
+ << dst_input_array->shape().Get(i);
+ }
+ } else {
+ dst_input_array->mutable_shape()->CopyFrom(
+ specified_input_array.shape());
+ }
+ }
+ }
+
+ if (model_flags.output_arrays_size() > 0) {
+ model->flags.mutable_output_arrays()->CopyFrom(model_flags.output_arrays());
+ }
+
+#define RESOLVE_MODEL_FLAG(name) \
+ if (model_flags.has_##name()) { \
+ if (model->flags.has_##name()) { \
+ QCHECK_EQ(model_flags.name(), model->flags.name()) \
+ << "Specified " #name " flag with value: " << model_flags.name() \
+ << " does not agree with already defined " #name \
+ " of this model, with value: " \
+ << model->flags.name(); \
+ } else { \
+ model->flags.set_##name(model_flags.name()); \
+ } \
+ }
+
+ RESOLVE_MODEL_FLAG(variable_batch)
+ RESOLVE_MODEL_FLAG(drop_control_dependency)
+
+#undef RESOLVE_MODEL_FLAG
+
+ if (model->flags.rnn_states_size() == 0) {
+ model->flags.mutable_rnn_states()->CopyFrom(model_flags.rnn_states());
+ } else {
+ CHECK_EQ(model->flags.rnn_states_size(), model_flags.rnn_states_size());
+ for (int i = 0; i < model->flags.rnn_states_size(); i++) {
+ CHECK_EQ(model->flags.rnn_states(i).state_array(),
+ model_flags.rnn_states(i).state_array());
+ CHECK_EQ(model->flags.rnn_states(i).back_edge_source_array(),
+ model_flags.rnn_states(i).back_edge_source_array());
+ }
+ }
+
+ if (model->flags.model_checks_size() == 0) {
+ model->flags.mutable_model_checks()->CopyFrom(model_flags.model_checks());
+ }
+
+ QCHECK_GT(model->flags.input_arrays_size(), 0)
+ << "This model does not define input arrays, so a "
+ "--input_arrays flag must be given on the command-line.";
+ QCHECK_GT(model->flags.output_arrays_size(), 0)
+ << "This model does not define output arrays, so a "
+ "--output_arrays flag must be given on the command-line.";
+
+ for (const auto& input_array_proto : model->flags.input_arrays()) {
+ QCHECK(!input_array_proto.shape().empty())
+ << "This model does not have shape defined for input array "
+ << input_array_proto.name()
+ << ", so one must be specified by a non-empty --input_shape "
+ "command-line flag.";
+
+ auto& input_array = model->GetOrCreateArray(input_array_proto.name());
+ if (input_array.data_type == ArrayDataType::kNone) {
+ // We start out with a float input array;
+ // that may get replaced by a uint8 array later, by
+ // MakeInitialDequantizeOp.
+ input_array.data_type = ArrayDataType::kFloat;
+ }
+
+ // Compare/merge the model->flags describing the input_shape with
+ // the actual input array's shape.
+ auto& input_array_dims = *input_array.mutable_shape()->mutable_dims();
+ if (input_array_dims.empty()) {
+ for (auto dim : input_array_proto.shape()) {
+ CHECK_GE(dim, 1);
+ input_array_dims.push_back(dim);
+ }
+ } else {
+ CHECK_EQ(input_array_dims.size(), input_array_proto.shape_size());
+ for (int i = 0; i < input_array_dims.size(); i++) {
+ CHECK_EQ(input_array_dims[i], input_array_proto.shape(i));
+ }
+ }
+
+ const float mean_value = input_array_proto.mean_value();
+ const float std_value = input_array_proto.std_value();
+ MinMax input_minmax;
+ input_minmax.min = (0.f - mean_value) / std_value;
+ input_minmax.max = (255.f - mean_value) / std_value;
+ if (input_array.minmax) {
+ if (input_array_proto.has_mean_value() ||
+ input_array_proto.has_std_value()) {
+ CHECK(input_minmax == *input_array.minmax)
+ << input_minmax.min << ", " << input_minmax.max
+ << " != " << input_array.minmax->min << ", "
+ << input_array.minmax->max;
+ }
+ } else {
+ input_array.GetOrCreateMinMax() = input_minmax;
+ }
+ }
+ // Creation of the RNN state arrays
+ for (const auto& rnn_state : model->flags.rnn_states()) {
+ if (!rnn_state.manually_create()) {
+ continue;
+ }
+ CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(),
+ model);
+ }
+}
+
+void CheckIsReadyForQuantization(const Model& model) {
+ for (const auto& op : model.operators) {
+ for (const auto& input : op->inputs) {
+ const auto& input_array = model.GetArray(input);
+ if (input_array.data_type != ArrayDataType::kFloat) {
+ // The array is not floats, no quantization needed.
+ continue;
+ }
+ if (input_array.minmax) {
+ // The array has minmax, we're good.
+ continue;
+ }
+ if (input_array.buffer) {
+ // The array has a constant buffer, so we can
+ // fall back to computing the minmax from actual array entries
+ // (with a WARNING about possible accuracy implications).
+ continue;
+ }
+ LOG(FATAL)
+ << "Array " << input << ", which is an input to the "
+ << HelpfulOperatorTypeName(*op) << " operator producing the output "
+ << "array " << op->outputs[0] << ", is lacking min/max data, "
+ << "which is necessary for quantization. Either target a "
+ << "non-quantized output format, or change the input graph to "
+ << "contain min/max information, or pass --default_ranges_min= and "
+ << "--default_ranges_max= if you do not care about the accuracy of "
+ << "results.";
+ }
+ }
+}
+
+void UseDefaultMinMaxRangeValues(Model* model, double default_ranges_min,
+ double default_ranges_max) {
+ for (const auto& op : model->operators) {
+ for (const auto& input : op->inputs) {
+ auto& input_array = model->GetArray(input);
+ if (!input_array.minmax && !input_array.buffer) {
+ auto& minmax = input_array.GetOrCreateMinMax();
+ minmax.min = default_ranges_min;
+ minmax.max = default_ranges_max;
+ }
+ }
+ for (const auto& output : op->outputs) {
+ auto& output_array = model->GetArray(output);
+ if (!output_array.minmax && !output_array.buffer) {
+ auto& minmax = output_array.GetOrCreateMinMax();
+ minmax.min = default_ranges_min;
+ minmax.max = default_ranges_max;
+ }
+ }
+ }
+}
+
+int ElementSize(ArrayDataType data_type) {
+ switch (data_type) {
+ case ArrayDataType::kFloat:
+ return 4;
+ case ArrayDataType::kInt32:
+ return 4;
+ case ArrayDataType::kUint8:
+ return 1;
+ default:
+ LOG(FATAL) << "Should not get here.";
+ return 0;
+ }
+}
+
+void DropMinMax(Model* model, const string& array_name) {
+ auto& array = model->GetArray(array_name);
+ if (!!array.minmax) {
+ LOG(WARNING) << "Dropping MinMax information in array " << array_name
+ << ". Expect inaccuracy in quantized inference.";
+ array.minmax = nullptr;
+ }
+}
+
+bool IsAllocatableTransientArray(const Model& model, const string& array_name) {
+ // The model's input and output arrays are externally allocated.
+ // They are not transient arrays.
+ if (IsInputArray(model, array_name)) {
+ return false;
+ }
+ for (const string& output_array : model.flags.output_arrays()) {
+ if (array_name == output_array) {
+ return false;
+ }
+ }
+ const auto& array = model.arrays.at(array_name);
+ // An array with a constant buffer isn't a transient array.
+ if (!!array->buffer) {
+ return false;
+ }
+ // An array without shape isn't allocatable.
+ if (!array->has_shape()) {
+ return false;
+ }
+ return true;
+}
+
+string AvailableArrayName(const Model& model, const string& name) {
+ if (!model.arrays.count(name)) {
+ return name;
+ }
+ const int kNumSuffixesToTry = 1000;
+ for (int i = 0; i < kNumSuffixesToTry; i++) {
+ const string& name_with_suffix = toco::port::StringF("%s_%d", name, i);
+ if (!model.arrays.count(name_with_suffix)) {
+ return name_with_suffix;
+ }
+ }
+ LOG(FATAL) << "Could not find an available array name starting with " << name
+ << ". Tried " << kNumSuffixesToTry << " suffixes, all were taken!";
+ return "";
+}
+
+string ShapeToString(const Shape& shape) {
+ if (shape.dimensions_count() == 0) {
+ return "[]";
+ }
+
+ return absl::StrCat("[ ", absl::StrJoin(shape.dims(), ", "), " ]");
+}
+
+void PrintArrayShape(Model* model, const string& name) {
+ if (!model->arrays[name]->has_shape()) {
+ LOG(INFO) << name << " has no shape";
+ return;
+ }
+ LOG(INFO) << name
+ << " has shape: " << ShapeToString(model->arrays[name]->shape());
+}
+
+bool IsArrayFullyConnectedWeights(const Model& model, const string& name) {
+ bool is_fc_weights = false;
+ bool is_something_else = false;
+ for (const auto& op : model.operators) {
+ for (int input_index = 0; input_index < op->inputs.size(); input_index++) {
+ if (op->inputs[input_index] == name) {
+ if (op->type == OperatorType::kFullyConnected && input_index == 1) {
+ is_fc_weights = true;
+ } else {
+ is_something_else = true;
+ }
+ }
+ }
+ }
+ CHECK(!(is_fc_weights && is_something_else));
+ return is_fc_weights;
+}
+
+bool EstimateArithmeticOpsCount(const Model& model, int64* result) {
+ int64 total = 0;
+ for (const auto& op : model.operators) {
+ switch (op->type) {
+ case OperatorType::kFullyConnected:
+ case OperatorType::kConv:
+ case OperatorType::kDepthwiseConv: {
+ const auto& output_array = model.GetArray(op->outputs[0]);
+ const auto& weights_array = model.GetArray(op->inputs[1]);
+ if (!output_array.has_shape() || !weights_array.has_shape()) {
+ return false;
+ }
+ int cols = 1;
+ for (int i = 0; i < output_array.shape().dimensions_count() - 1; i++) {
+ cols *= output_array.shape().dims(i);
+ }
+ const int64 cost_per_col =
+ 2 * RequiredBufferSizeForShape(weights_array.shape());
+ total += cost_per_col * cols;
+ if (op->inputs.size() > 2) {
+ // There is a bias vector. One more op per output value.
+ total += RequiredBufferSizeForShape(output_array.shape());
+ }
+ break;
+ }
+ case OperatorType::kAdd:
+ case OperatorType::kSub:
+ case OperatorType::kMul: {
+ const auto& output_array = model.GetArray(op->outputs[0]);
+ if (!output_array.has_shape()) {
+ return false;
+ }
+ total += RequiredBufferSizeForShape(output_array.shape());
+ break;
+ }
+ case OperatorType::kLogistic:
+ case OperatorType::kSoftmax:
+ case OperatorType::kTanh: {
+ const auto& output_array = model.GetArray(op->outputs[0]);
+ if (!output_array.has_shape()) {
+ return false;
+ }
+ // As a very rough ballpark, the cost of evaluating a math function
+ // such as tanh or logistic is about 32 multiplications, and about as
+ // many additions/subtractions. (Just a power-of-two order-of-magnitude
+ // from looking at actual implementations that we use in runtime/ code).
+ total += 64 * RequiredBufferSizeForShape(output_array.shape());
+ break;
+ }
+ case OperatorType::kMaxPool: {
+ const auto& maxpool = *static_cast<const MaxPoolOperator*>(op.get());
+ const auto& output_array = model.GetArray(op->outputs[0]);
+ if (!output_array.has_shape()) {
+ return false;
+ }
+ total += RequiredBufferSizeForShape(output_array.shape()) *
+ maxpool.kheight * maxpool.kwidth;
+ break;
+ }
+ case OperatorType::kAveragePool: {
+ const auto& avgpool =
+ *static_cast<const AveragePoolOperator*>(op.get());
+ const auto& output_array = model.GetArray(op->outputs[0]);
+ if (!output_array.has_shape()) {
+ return false;
+ }
+ total += RequiredBufferSizeForShape(output_array.shape()) *
+ avgpool.kheight * avgpool.kwidth;
+ break;
+ }
+ case OperatorType::kL2Pool: {
+ const auto* maxpool = static_cast<const MaxPoolOperator*>(op.get());
+ const auto& output_array = model.GetArray(op->outputs[0]);
+ if (!output_array.has_shape()) {
+ return false;
+ }
+ // The sum of squares requires (kheight*kwidth) multiply-adds,
+ // and then there is the sqrt which we ballpark at 32 ops.
+ const int64 cost_per_val = 2 * maxpool->kheight * maxpool->kwidth + 32;
+ total +=
+ RequiredBufferSizeForShape(output_array.shape()) * cost_per_val;
+ break;
+ }
+ case OperatorType::kL2Normalization: {
+ const auto& output_array = model.GetArray(op->outputs[0]);
+ if (!output_array.has_shape()) {
+ return false;
+ }
+ // Computing the squared L2 norm is N multiply-adds so 2N ops,
+ // then the single inverse-sqrt is negligible, then we multiply each
+ // value by the resulting multiplier, so an extra N ops. Total 3N ops.
+ total += 3 * RequiredBufferSizeForShape(output_array.shape());
+ break;
+ }
+ default:
+ break;
+ }
+ }
+ *result = total;
+ return true;
+}
+
+namespace {
+
+void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order,
+ std::vector<int>* shuffle) {
+ CHECK_EQ(AxesCount(input_axes_order), AxesCount(output_axes_order));
+ shuffle->resize(4);
+ for (int i = 0; i < 4; i++) {
+ (*shuffle)[i] = i;
+ }
+ if (input_axes_order == output_axes_order) {
+ // nothing to do
+ } else if (AxesCount(input_axes_order) == 2) {
+ shuffle->resize(2);
+ (*shuffle)[0] = 1;
+ (*shuffle)[1] = 0;
+ } else if (input_axes_order == AxesOrder::kOHWI &&
+ output_axes_order == AxesOrder::kHWIO) {
+ // 3210 <- 3210
+ // HWIO <- OHWI
+ (*shuffle)[0] = 1;
+ (*shuffle)[1] = 2;
+ (*shuffle)[2] = 3;
+ (*shuffle)[3] = 0;
+ } else if (input_axes_order == AxesOrder::kHWIO &&
+ output_axes_order == AxesOrder::kOHWI) {
+ // 3210 <- 3210
+ // OHWI <- HWIO
+ (*shuffle)[0] = 3;
+ (*shuffle)[1] = 0;
+ (*shuffle)[2] = 1;
+ (*shuffle)[3] = 2;
+ } else {
+ LOG(FATAL) << "Bad shuffle";
+ }
+}
+
+// Extend shuffle is designed to match ExtendShape, which pads the shape with
+// unit dimensions at the beginning.
+void ExtendShuffle(const std::vector<int>& input_shuffle, int newdim,
+ std::vector<int>* extended_shuffle) {
+ *extended_shuffle = input_shuffle;
+ CHECK(newdim >= input_shuffle.size());
+ const int pad_size = newdim - input_shuffle.size();
+ extended_shuffle->resize(newdim);
+ for (int i = 0; i < pad_size; i++) {
+ (*extended_shuffle)[i] = i;
+ }
+ for (int i = pad_size; i < newdim; i++) {
+ (*extended_shuffle)[i] = input_shuffle[i - pad_size] + pad_size;
+ }
+}
+
+} // end anonymous namespace
+
+void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order,
+ AxesOrder output_axes_order, Shape* output_shape) {
+ if (input_axes_order == AxesOrder::kHWIM &&
+ output_axes_order == AxesOrder::k1HWO) {
+ // This special case isn't just a permutation, the IM pair of dims get
+ // merged into the 3 dim, so we have to special-case it.
+ *output_shape = Shape({1, input_shape.dims(0), input_shape.dims(1),
+ input_shape.dims(3) * input_shape.dims(2)});
+ } else {
+ std::vector<int> shuffle;
+ GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
+ std::vector<int>* output_dims = output_shape->mutable_dims();
+ output_dims->resize(input_shape.dimensions_count());
+ for (int i = 0; i < input_shape.dimensions_count(); i++) {
+ (*output_dims)[i] = input_shape.dims(shuffle[i]);
+ }
+ }
+}
+
+void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
+ AxesOrder output_axes_order, const Shape& output_shape,
+ const float* input_data, float* output_data) {
+ if (input_axes_order == AxesOrder::kHWIM &&
+ output_axes_order == AxesOrder::k1HWO) {
+ // This special case isn't just a permutation, the IM pair of dims get
+ // merged into the O dim, so we have to special-case it. Fortunately,
+ // as far as array shuffling is concerned, it's just the identity
+ // transformation.
+ memcpy(output_data, input_data,
+ RequiredBufferSizeForShape(input_shape) * sizeof(output_data[0]));
+ return;
+ }
+ CHECK(input_shape.dimensions_count() == output_shape.dimensions_count());
+ const int dim = input_shape.dimensions_count();
+ CHECK_LE(dim, 4);
+ std::vector<int> shuffle;
+ GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
+ CHECK(shuffle.size() >= dim);
+ for (int i = 0; i < dim; i++) {
+ CHECK(shuffle[i] >= 0 && shuffle[i] < dim);
+ CHECK(input_shape.dims(shuffle[i]) == output_shape.dims(i));
+ }
+ Shape extended_input_shape = input_shape;
+ ExtendShape(&extended_input_shape, 4);
+ Shape extended_output_shape = output_shape;
+ ExtendShape(&extended_output_shape, 4);
+ std::vector<int> extended_shuffle;
+ ExtendShuffle(shuffle, 4, &extended_shuffle);
+
+ const std::vector<int>& extended_input_dims = extended_input_shape.dims();
+ const std::vector<int>& extended_output_dims = extended_output_shape.dims();
+
+ // TODO(starka): Rework to handle different numbers of dimensions.
+ int input_strides[4];
+ input_strides[3] = 1;
+ input_strides[2] = extended_input_dims[3];
+ input_strides[1] = input_strides[2] * extended_input_dims[2];
+ input_strides[0] = input_strides[1] * extended_input_dims[1];
+ const int input_stride_0 = input_strides[extended_shuffle[3]];
+ const int input_stride_1 = input_strides[extended_shuffle[2]];
+ const int input_stride_2 = input_strides[extended_shuffle[1]];
+ const int input_stride_3 = input_strides[extended_shuffle[0]];
+
+ const int output_size_0 = extended_output_dims[3];
+ const int output_size_1 = extended_output_dims[2];
+ const int output_size_2 = extended_output_dims[1];
+ const int output_size_3 = extended_output_dims[0];
+ const int output_stride_0 = 1;
+ const int output_stride_1 = output_size_0;
+ const int output_stride_2 = output_stride_1 * output_size_1;
+ const int output_stride_3 = output_stride_2 * output_size_2;
+
+ for (int i3 = 0; i3 < output_size_3; i3++) {
+ const float* const input_ptr_3 = input_data + i3 * input_stride_3;
+ float* const output_ptr_3 = output_data + i3 * output_stride_3;
+ for (int i2 = 0; i2 < output_size_2; i2++) {
+ const float* const input_ptr_2 = input_ptr_3 + i2 * input_stride_2;
+ float* const output_ptr_2 = output_ptr_3 + i2 * output_stride_2;
+ for (int i1 = 0; i1 < output_size_1; i1++) {
+ const float* input_ptr = input_ptr_2 + i1 * input_stride_1;
+ float* output_ptr = output_ptr_2 + i1 * output_stride_1;
+ float* const output_ptr_end =
+ output_ptr + output_size_0 * output_stride_0;
+ while (output_ptr != output_ptr_end) {
+ *output_ptr = *input_ptr;
+ input_ptr += input_stride_0;
+ output_ptr += output_stride_0;
+ }
+ }
+ }
+ }
+}
+
+int AxesCount(AxesOrder axes_order) {
+ switch (axes_order) {
+ case AxesOrder::kOneAxis:
+ return 1;
+ case AxesOrder::kRC:
+ return 2;
+ case AxesOrder::kCR:
+ return 2;
+ case AxesOrder::kHWIO:
+ return 4;
+ case AxesOrder::kOHWI:
+ return 4;
+ case AxesOrder::kHWIM:
+ return 4;
+ case AxesOrder::k1HWO:
+ return 4;
+ case AxesOrder::kNHWC:
+ return 4;
+ default:
+ LOG(FATAL) << "Bad AxesOrder";
+ return 0;
+ }
+}
+
+bool IsDiscardableArray(const Model& model, const string& array_name) {
+ for (const auto& input_array : model.flags.input_arrays()) {
+ if (array_name == input_array.name()) {
+ return false;
+ }
+ }
+ for (const string& output_array : model.flags.output_arrays()) {
+ if (array_name == output_array) {
+ return false;
+ }
+ }
+ for (const auto& rnn_state : model.flags.rnn_states()) {
+ if (array_name == rnn_state.state_array()) {
+ return false;
+ }
+ if (array_name == rnn_state.back_edge_source_array()) {
+ return false;
+ }
+ }
+ return true;
+}
+
+void CheckFinalDataTypesSatisfied(const Model& model) {
+ for (const auto& array_entry : model.arrays) {
+ const auto& array = *array_entry.second;
+ if (array.final_data_type != ArrayDataType::kNone) {
+ CHECK(array.final_data_type == array.data_type);
+ }
+ }
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
new file mode 100644
index 0000000000..093945edb3
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -0,0 +1,292 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_
+
+#include <algorithm>
+#include <cmath>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "google/protobuf/text_format.h"
+#include "tensorflow/core/platform/logging.h"
+#if TOCO_SUPPORT_PORTABLE_PROTOS
+#include "third_party/protobuf/src/google/protobuf/text_format.h"
+#endif // TOCO_SUPPORT_PORTABLE_PROTOS
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+
+// TODO(aselle): Replace with using a container specific hash override instead.
+namespace std {
+template <>
+struct hash<toco::OperatorType> {
+ size_t operator()(const toco::OperatorType& op) const {
+ return std::hash<size_t>()(static_cast<size_t>(op));
+ }
+};
+} // namespace std
+
+namespace toco {
+
+constexpr int kLogLevelModelChanged = 1;
+constexpr int kLogLevelModelUnchanged = 2;
+
+string LogName(const Operator& op);
+
+bool IsInputArray(const Model& model, const string& name);
+bool IsArrayConsumed(const Model& model, const string& name);
+int CountTrueOutputs(const Model& model, const Operator& op);
+
+int CountOpsWithInput(const Model& model, const string& array_name);
+bool DeleteArrayIfUnused(const string& array_name, Model* model);
+
+std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithOutput(
+ const Model& model, const string& array_name);
+Operator* GetOpWithOutput(const Model& model, const string& array_name);
+
+std::vector<std::unique_ptr<Operator>>::iterator FindOpWithOutput(
+ Model& model, const string& array_name);
+Operator* GetOpWithOutput(const Model& model, const string& array_name);
+
+std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithInput(
+ const Model& model, const string& array_name);
+Operator* GetOpWithInput(const Model& model, const string& array_name);
+Operator* GetFirstOpWithInput(const Model& model, const string& array_name);
+
+std::vector<std::unique_ptr<Operator>>::const_iterator FindOp(
+ const Model& model, const Operator* op);
+std::vector<std::unique_ptr<Operator>>::iterator FindOp(Model& model,
+ const Operator* op);
+
+const char* OperatorTypeName(OperatorType type);
+string HelpfulOperatorTypeName(const Operator& op);
+
+void DumpGraphvizVideoFrame(const Model& model);
+void LogDump(int log_level, const string& message, const Model& model);
+void LogSummary(int log_level, const string& message, const Model& model);
+
+inline bool ParseFromStringOverload(const std::string& in,
+ TFLITE_PROTO_NS::Message* proto) {
+ return TFLITE_PROTO_NS::TextFormat::ParseFromString(in, proto);
+}
+
+template <typename Proto>
+bool ParseFromStringEitherTextOrBinary(const std::string& input_file_contents,
+ Proto* proto) {
+ if (proto->ParseFromString(input_file_contents)) {
+ return true;
+ }
+
+ if (ParseFromStringOverload(input_file_contents, proto)) {
+ return true;
+ }
+
+ return false;
+}
+
+// TODO(b/36075966): Clean up when dims superseded by array shape.
+void ExtendShape(Shape* shape, int new_shape_size);
+
+// TODO(b/36075966): Clean up when dims superseded by array shape.
+void UnextendShape(Shape* shape, int new_shape_size);
+
+// Checks (using CHECK) that all dimensions of 'shape' are at least 1.
+void CheckShapeDimensions(const Shape& shape);
+
+// Given two shapes with potentially different dimensionality and dimension
+// arrays d0 and d1. Without loss of generality, assume that shape0 may have
+// higher dimensionality (length(d0) >= length(d1)). Then shape0 and shape1
+// "agree up to broadcasting" if:
+// - When walking the d0 and d1 from back to front with indices i0, i1,
+// d0[i0] == d1[i1] or d0[i0] == 1 or d1[i1] == 1, for each dimension until
+// i1 == 0 (inclusive).
+bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1);
+
+// A stricter constraint than ShapesAgreeUpToBroadcasting().
+//
+// Given two shapes with potentially different dimensionality and dimension
+// arrays d0 and d1. Without loss of generality, assume that shape0 may have
+// higher dimensionality (length(d0) >= length(d1)). Then shape0 and shape1
+// "agree up to extending" if:
+// - When walking the d0 and d1 from back to front with indices i0, i1,
+// d0[i0] == d1[i1] for each dimension until i1 == 0 (inclusive).
+// - For the remaining indices [0..i0), d0[i0] == 1.
+bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1);
+
+bool IsArrayFullyConnectedWeights(const Model& model, const string& name);
+
+// If there is a wildcard dimension (-1), this may return a negative value.
+int RequiredBufferSizeForShape(const Shape& shape);
+
+bool IsConstantParameterArray(const Model& model, const string& name);
+
+void CheckNoMissingArray(const Model& model);
+void CheckInvariants(const Model& model);
+
+void CheckModelCounts(const Model& model);
+
+void FixOperatorOrdering(Model* model);
+void FixNoMissingArray(Model* model);
+void FixNoOrphanedArray(Model* model);
+
+void ResolveModelFlags(const ModelFlags& model_flags, Model* model);
+
+template <ArrayDataType A>
+void GetQuantizationParamsFromMinMax(const ModelFlags& model_flags,
+ const MinMax& minmax,
+ QuantizationParams* quantization_params) {
+ using Integer = DataType<A>;
+ const Integer qmin = std::numeric_limits<Integer>::min();
+ const Integer qmax = std::numeric_limits<Integer>::max();
+ const double qmin_double = qmin;
+ const double qmax_double = qmax;
+ const double rmin = minmax.min;
+ const double rmax = minmax.max;
+ // 0 should always be a representable value. Let's assume that the initial
+ // min,max range contains 0.
+ CHECK_LE(rmin, 0.);
+ CHECK_GE(rmax, 0.);
+ if (rmin == rmax) {
+ // Special case where the min,max range is a point. Should be {0}.
+ CHECK_EQ(rmin, 0.);
+ CHECK_EQ(rmax, 0.);
+ quantization_params->zero_point = 0;
+ quantization_params->scale = 0.;
+ return;
+ }
+
+ // General case.
+ //
+ // First determine the scale.
+ const double scale = (rmax - rmin) / (qmax_double - qmin_double);
+
+ // Zero-point computation.
+ // First the initial floating-point computation. The zero-point can be
+ // determined from solving an affine equation for any known pair
+ // (real value, corresponding quantized value).
+ // We know two such pairs: (rmin, qmin) and (rmax, qmax).
+ // The arithmetic error on the zero point computed from either pair
+ // will be roughly machine_epsilon * (sum of absolute values of terms)
+ // so we want to use the variant that adds the smaller terms.
+ const double zero_point_from_min = qmin_double - rmin / scale;
+ const double zero_point_from_max = qmax_double - rmax / scale;
+ const double zero_point_from_min_error =
+ std::abs(qmin_double) + std::abs(rmin / scale);
+ const double zero_point_from_max_error =
+ std::abs(qmax_double) + std::abs(rmax / scale);
+
+ const double zero_point_double =
+ zero_point_from_min_error < zero_point_from_max_error
+ ? zero_point_from_min
+ : zero_point_from_max;
+
+ // Now we need to nudge the zero point to be an integer
+ // (our zero points are integer, and this is motivated by the requirement
+ // to be able to represent the real value "0" exactly as a quantized value,
+ // which is required in multiple places, for example in Im2col with SAME
+ // padding).
+ Integer nudged_zero_point = 0;
+ if (zero_point_double < qmin_double) {
+ nudged_zero_point = qmin;
+ } else if (zero_point_double > qmax_double) {
+ nudged_zero_point = qmax;
+ } else {
+ nudged_zero_point = static_cast<Integer>(std::round(zero_point_double));
+ }
+ // The zero point should always be in the range of quantized value,
+ // [qmin, qmax].
+ CHECK_GE(nudged_zero_point, qmin);
+ CHECK_LE(nudged_zero_point, qmax);
+
+ // Finally, store the result nudged quantization params.
+ quantization_params->zero_point = nudged_zero_point;
+ quantization_params->scale = scale;
+}
+
+void CheckIsReadyForQuantization(const Model& model);
+void UseDefaultMinMaxRangeValues(Model* model, double default_ranges_min,
+ double default_ranges_max);
+
+inline int Offset(const Shape& shape, const std::vector<int>& indices) {
+ DCHECK_EQ(shape.dimensions_count(), indices.size());
+ const int dims_count = shape.dimensions_count();
+ int offset = 0;
+ for (int i = 0; i < dims_count; i++) {
+ const int index = indices[i];
+ DCHECK(index >= 0 && index < shape.dims(i));
+ offset *= shape.dims(i);
+ offset += index;
+ }
+ return offset;
+}
+
+inline std::vector<int> ReverseOffset(const Shape& shape, int index) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, RequiredBufferSizeForShape(shape));
+ const int dims_count = shape.dimensions_count();
+ std::vector<int> indices(dims_count);
+ int residual = index;
+ for (int i = dims_count - 1; i >= 0; i--) {
+ indices[i] = residual % shape.dims(i);
+ residual /= shape.dims(i);
+ }
+ return indices;
+}
+
+int ElementSize(ArrayDataType data_type);
+
+void DropMinMax(Model* model, const string& array_name);
+
+bool IsAllocatableTransientArray(const Model& model, const string& array_name);
+
+void CreateOrCheckRnnStateArray(const string& name, int size, Model* model);
+
+string AvailableArrayName(const Model& model, const string& name);
+
+// Formats a shape as a string: [ dims(0), dims(1), ..., dims(num_dims-1) ].
+string ShapeToString(const Shape& shape);
+
+void PrintArrayShape(Model* model, const string& name);
+
+void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
+ std::vector<int>* out_dims);
+
+bool EstimateArithmeticOpsCount(const Model& model, int64* result);
+
+int AxesCount(AxesOrder axes_order);
+
+void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order,
+ AxesOrder output_axes_order, Shape* output_shape);
+void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
+ AxesOrder output_axes_order, const Shape& output_shape,
+ const float* input_data, float* output_data);
+
+// Returns true if it may be OK for any graph transformation to ever discard
+// that array. The idea is that we can't ever discard arrays that are either
+// an input or an output of the whole graph, or that appear in RNN back-edges,
+// as that would undercut explicit flags that the user might pass.
+bool IsDiscardableArray(const Model& model, const string& array_name);
+
+void CheckFinalDataTypesSatisfied(const Model& model);
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_
diff --git a/tensorflow/contrib/lite/toco/tooling_util_test.cc b/tensorflow/contrib/lite/toco/tooling_util_test.cc
new file mode 100644
index 0000000000..22955ce956
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tooling_util_test.cc
@@ -0,0 +1,96 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <tuple>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+namespace toco {
+
+enum class Agreement { kBroadcast, kExtend, kBroadcastNotExtend, kNeither };
+
+// A pair of Shapes and whether they should agree up to broadcasting, extending
+// or neither.
+struct ShapePair {
+ Shape left;
+ Shape right;
+ Agreement agreement;
+};
+
+std::vector<ShapePair> CreateShapePairs() {
+ return std::vector<ShapePair>(
+ {// These agree up to broadcast.
+ {Shape({3}), Shape({3}), Agreement::kBroadcast},
+ {Shape({256, 256, 3}), Shape({256, 256, 3}), Agreement::kBroadcast},
+ {Shape({256, 256, 3}), Shape({3}), Agreement::kBroadcast},
+ {Shape({8, 1, 6, 1}), Shape({7, 1, 5}), Agreement::kBroadcast},
+
+ // These extend (and therefore broadcast).
+ {Shape({3}), Shape({3}), Agreement::kExtend},
+ {Shape({256, 256, 3}), Shape({256, 256, 3}), Agreement::kExtend},
+ {Shape({1, 1, 3}), Shape({1, 1, 3}), Agreement::kExtend},
+ {Shape({1, 1, 3}), Shape({3}), Agreement::kExtend},
+ {Shape({1, 1, 3}), Shape({1, 3}), Agreement::kExtend},
+
+ // These strictly broadcast and do not extend.
+ {Shape({256, 256, 3}), Shape({3}), Agreement::kBroadcastNotExtend},
+ {Shape({5, 4}), Shape({1}), Agreement::kBroadcastNotExtend},
+ {Shape({5, 4}), Shape({4}), Agreement::kBroadcastNotExtend},
+ {Shape({15, 3, 5}), Shape({15, 1, 5}), Agreement::kBroadcastNotExtend},
+ {Shape({15, 3, 5}), Shape({3, 5}), Agreement::kBroadcastNotExtend},
+ {Shape({15, 3, 5}), Shape({3, 1}), Agreement::kBroadcastNotExtend},
+
+ // These do not broadcast (and therefore also do not extend).
+ {Shape({3}), Shape({4}), Agreement::kNeither},
+ {Shape({2, 1}), Shape({8, 4, 3}), Agreement::kNeither}});
+}
+
+// ShapeTest is an empty parameterized test fixture since there is no state.
+class ShapeTest : public ::testing::TestWithParam<ShapePair> {};
+
+TEST_P(ShapeTest, Agrees) {
+ const ShapePair& param = GetParam();
+
+ switch (param.agreement) {
+ case Agreement::kBroadcast: {
+ EXPECT_TRUE(ShapesAgreeUpToBroadcasting(param.left, param.right));
+ break;
+ }
+ case Agreement::kExtend: {
+ EXPECT_TRUE(ShapesAgreeUpToExtending(param.left, param.right));
+ // Anything that extends should also broadcast.
+ EXPECT_TRUE(ShapesAgreeUpToBroadcasting(param.left, param.right));
+ break;
+ }
+ case Agreement::kBroadcastNotExtend: {
+ // Verify that it strictly broadcasts but does not extend.
+ EXPECT_TRUE(ShapesAgreeUpToBroadcasting(param.left, param.right));
+ EXPECT_FALSE(ShapesAgreeUpToExtending(param.left, param.right));
+ break;
+ }
+ case Agreement::kNeither: {
+ EXPECT_FALSE(ShapesAgreeUpToExtending(param.left, param.right));
+ EXPECT_FALSE(ShapesAgreeUpToBroadcasting(param.left, param.right));
+ break;
+ }
+ }
+}
+
+INSTANTIATE_TEST_CASE_P(AgreeBroadcast, ShapeTest,
+ ::testing::ValuesIn(CreateShapePairs()));
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD
new file mode 100644
index 0000000000..2d918fd4e8
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/BUILD
@@ -0,0 +1,60 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+cc_binary(
+ name = "generate_op_registrations",
+ srcs = ["gen_op_registration_main.cc"],
+ deps = [
+ "//tensorflow/contrib/lite/tools:gen_op_registration",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "gen_op_registration",
+ srcs = ["gen_op_registration.cc"],
+ hdrs = ["gen_op_registration.h"],
+ deps = [
+ "//tensorflow/contrib/lite:framework",
+ "@com_googlesource_code_re2//:re2",
+ ],
+)
+
+cc_test(
+ name = "gen_op_registration_test",
+ srcs = ["gen_op_registration_test.cc"],
+ data = [
+ "//tensorflow/contrib/lite:testdata/0_subgraphs.bin",
+ "//tensorflow/contrib/lite:testdata/2_subgraphs.bin",
+ "//tensorflow/contrib/lite:testdata/empty_model.bin",
+ "//tensorflow/contrib/lite:testdata/test_model.bin",
+ "//tensorflow/contrib/lite:testdata/test_model_broken.bin",
+ ],
+ deps = [
+ ":gen_op_registration",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "mutable_op_resolver",
+ srcs = ["mutable_op_resolver.cc"],
+ hdrs = ["mutable_op_resolver.h"],
+ deps = ["//tensorflow/contrib/lite:framework"],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/tools/gen_op_registration.cc b/tensorflow/contrib/lite/tools/gen_op_registration.cc
new file mode 100644
index 0000000000..57c2567e3b
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/gen_op_registration.cc
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string>
+#include <vector>
+
+#include "third_party/re2/re2.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+
+string NormalizeCustomOpName(const string& op) {
+ string method(op);
+ RE2::GlobalReplace(&method, "([a-z])([A-Z])", "\\1_\\2");
+ std::transform(method.begin(), method.end(), method.begin(), ::toupper);
+ return method;
+}
+
+void ReadOpsFromModel(const ::tflite::Model* model,
+ std::vector<string>* builtin_ops,
+ std::vector<string>* custom_ops) {
+ if (!model) return;
+ auto opcodes = model->operator_codes();
+ if (!opcodes) return;
+ for (const auto* opcode : *opcodes) {
+ if (opcode->builtin_code() != ::tflite::BuiltinOperator_CUSTOM) {
+ builtin_ops->push_back(
+ tflite::EnumNameBuiltinOperator(opcode->builtin_code()));
+ } else {
+ custom_ops->push_back(opcode->custom_code()->c_str());
+ }
+ }
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/tools/gen_op_registration.h b/tensorflow/contrib/lite/tools/gen_op_registration.h
new file mode 100644
index 0000000000..363bb2335c
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/gen_op_registration.h
@@ -0,0 +1,38 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_
+
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+
+// Convert the custom op name to registration name following the convention.
+// Example:
+// "custom_op" -> "CUSTOM_OP"
+// "CustomOp" -> "CUSTOM_OP"
+// Note "Register_" suffix will be added later in the tool.
+string NormalizeCustomOpName(const string& op);
+
+// Read ops from the TFLite model.
+// Enum name of builtin ops will be stored, such as "CONV_2D".
+// Custom op name will be stored as it is.
+void ReadOpsFromModel(const ::tflite::Model* model,
+ std::vector<string>* builtin_ops,
+ std::vector<string>* custom_ops);
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_
diff --git a/tensorflow/contrib/lite/tools/gen_op_registration_main.cc b/tensorflow/contrib/lite/tools/gen_op_registration_main.cc
new file mode 100644
index 0000000000..7b27066a21
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/gen_op_registration_main.cc
@@ -0,0 +1,98 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <fstream>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/tools/gen_op_registration.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+using tensorflow::Flag;
+using tensorflow::Flags;
+
+namespace {
+
+void GenerateFileContent(const string& filename,
+ const std::vector<string>& builtin_ops,
+ const std::vector<string>& custom_ops) {
+ std::ofstream fout(filename);
+
+ fout << "#include "
+ "\"third_party/tensorflow/contrib/lite/model.h\"\n";
+ fout << "#include "
+ "\"third_party/tensorflow/contrib/lite/tools/mutable_op_resolver.h\"\n";
+ fout << "namespace tflite {\n";
+ fout << "namespace ops {\n";
+ if (!builtin_ops.empty()) {
+ fout << "namespace builtin {\n";
+ fout << "// Forward-declarations for the builtin ops.\n";
+ for (const auto& op : builtin_ops) {
+ fout << "TfLiteRegistration* Register_" << op << "();\n";
+ }
+ fout << "} // namespace builtin\n";
+ }
+
+ if (!custom_ops.empty()) {
+ fout << "namespace custom {\n";
+ fout << "// Forward-declarations for the custom ops.\n";
+ for (const auto& op : custom_ops) {
+ fout << "TfLiteRegistration* Register_"
+ << ::tflite::NormalizeCustomOpName(op) << "();\n";
+ }
+ fout << "} // namespace custom\n";
+ }
+ fout << "} // namespace ops\n";
+ fout << "} // namespace tflite\n";
+
+ fout << "void RegisterSelectedOps(::tflite::MutableOpResolver* resolver) {\n";
+ for (const auto& op : builtin_ops) {
+ fout << " resolver->AddBuiltin(::tflite::BuiltinOperator_" << op
+ << ", ::tflite::ops::builtin::Register_" << op << "());\n";
+ }
+ for (const auto& op : custom_ops) {
+ fout << " resolver->AddCustom(\"" << op
+ << "\", ::tflite::ops::custom::Register_"
+ << ::tflite::NormalizeCustomOpName(op) << "());\n";
+ }
+ fout << "}\n";
+ fout.close();
+}
+} // namespace
+
+int main(int argc, char** argv) {
+ string input_model;
+ string output_registration;
+ std::vector<tensorflow::Flag> flag_list = {
+ Flag("input_model", &input_model, "path to the tflite model"),
+ Flag("output_registration", &output_registration,
+ "filename for generated registration code"),
+ };
+ Flags::Parse(&argc, argv, flag_list);
+
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+ std::vector<string> builtin_ops;
+ std::vector<string> custom_ops;
+
+ std::ifstream fin(input_model);
+ std::stringstream content;
+ content << fin.rdbuf();
+ const ::tflite::Model* model = ::tflite::GetModel(content.str().data());
+ ::tflite::ReadOpsFromModel(model, &builtin_ops, &custom_ops);
+ GenerateFileContent(output_registration, builtin_ops, custom_ops);
+ return 0;
+}
diff --git a/tensorflow/contrib/lite/tools/gen_op_registration_test.cc b/tensorflow/contrib/lite/tools/gen_op_registration_test.cc
new file mode 100644
index 0000000000..c65cffe340
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/gen_op_registration_test.cc
@@ -0,0 +1,87 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/gen_op_registration.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+using ::testing::ElementsAreArray;
+
+namespace tflite {
+
+class GenOpRegistrationTest : public ::testing::Test {
+ protected:
+ GenOpRegistrationTest() {}
+
+ void ReadOps(const string& model_path) {
+ auto model = FlatBufferModel::BuildFromFile(model_path.data());
+ if (model) {
+ ReadOpsFromModel(model->GetModel(), &builtin_ops_, &custom_ops_);
+ }
+ }
+
+ std::vector<string> builtin_ops_;
+ std::vector<string> custom_ops_;
+};
+
+TEST_F(GenOpRegistrationTest, TestNonExistantFiles) {
+ ReadOps("/tmp/tflite_model_1234");
+ EXPECT_EQ(builtin_ops_.size(), 0);
+ EXPECT_EQ(custom_ops_.size(), 0);
+}
+
+TEST_F(GenOpRegistrationTest, TestModels) {
+ ReadOps("third_party/tensorflow/contrib/lite/testdata/test_model.bin");
+ EXPECT_THAT(builtin_ops_, ElementsAreArray({"CONV_2D"}));
+ EXPECT_THAT(custom_ops_, ElementsAreArray({"testing_op"}));
+}
+
+TEST_F(GenOpRegistrationTest, TestEmptyModels) {
+ ReadOps("third_party/tensorflow/contrib/lite/testdata/empty_model.bin");
+ EXPECT_EQ(builtin_ops_.size(), 0);
+ EXPECT_EQ(custom_ops_.size(), 0);
+}
+
+TEST_F(GenOpRegistrationTest, TestZeroSubgraphs) {
+ ReadOps("third_party/tensorflow/contrib/lite/testdata/0_subgraphs.bin");
+ EXPECT_EQ(builtin_ops_.size(), 0);
+ EXPECT_EQ(custom_ops_.size(), 0);
+}
+
+TEST_F(GenOpRegistrationTest, TestBrokenMmap) {
+ ReadOps("third_party/tensorflow/contrib/lite/testdata/test_model_broken.bin");
+ EXPECT_EQ(builtin_ops_.size(), 0);
+ EXPECT_EQ(custom_ops_.size(), 0);
+}
+
+TEST_F(GenOpRegistrationTest, TestNormalizeCustomOpName) {
+ std::vector<std::pair<string, string>> testcase = {
+ {"CustomOp", "CUSTOM_OP"},
+ {"a", "A"},
+ {"custom_op", "CUSTOM_OP"},
+ {"customop", "CUSTOMOP"},
+ };
+
+ for (const auto& test : testcase) {
+ EXPECT_EQ(NormalizeCustomOpName(test.first), test.second);
+ }
+}
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: FLAGS_logtostderr = true;
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/tools/mutable_op_resolver.cc b/tensorflow/contrib/lite/tools/mutable_op_resolver.cc
new file mode 100644
index 0000000000..8a921d7c5a
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/mutable_op_resolver.cc
@@ -0,0 +1,43 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"
+
+namespace tflite {
+
+TfLiteRegistration* MutableOpResolver::FindOp(
+ tflite::BuiltinOperator op) const {
+ auto it = builtins_.find(op);
+ return it != builtins_.end() ? it->second : nullptr;
+}
+
+TfLiteRegistration* MutableOpResolver::FindOp(const char* op) const {
+ auto it = custom_ops_.find(op);
+ return it != custom_ops_.end() ? it->second : nullptr;
+}
+
+void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op,
+ TfLiteRegistration* registration) {
+ registration->builtin_code = op;
+ builtins_.insert(std::make_pair(op, registration));
+}
+
+void MutableOpResolver::AddCustom(const char* name,
+ TfLiteRegistration* registration) {
+ registration->builtin_code = BuiltinOperator_CUSTOM;
+ custom_ops_.insert(std::make_pair(std::string(name), registration));
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/tools/mutable_op_resolver.h b/tensorflow/contrib/lite/tools/mutable_op_resolver.h
new file mode 100644
index 0000000000..9546c32427
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/mutable_op_resolver.h
@@ -0,0 +1,45 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_
+
+#include <unordered_map>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+
+// An OpResolver that is mutable, also used as the op in gen_op_registration.
+// A typical usage:
+// MutableOpResolver resolver;
+// resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD());
+// resolver.AddCustom("CustomOp", Register_CUSTOM_OP());
+// InterpreterBuilder(model, resolver)(&interpreter);
+class MutableOpResolver : public OpResolver {
+ public:
+ MutableOpResolver() {}
+ TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override;
+ TfLiteRegistration* FindOp(const char* op) const override;
+ void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration);
+ void AddCustom(const char* name, TfLiteRegistration* registration);
+
+ private:
+ std::unordered_map<tflite::BuiltinOperator, TfLiteRegistration*> builtins_;
+ std::unordered_map<std::string, TfLiteRegistration*> custom_ops_;
+};
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_
diff --git a/tensorflow/contrib/lite/version.h b/tensorflow/contrib/lite/version.h
new file mode 100644
index 0000000000..a751afabe7
--- /dev/null
+++ b/tensorflow/contrib/lite/version.h
@@ -0,0 +1,23 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_VERSION_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_VERSION_H_
+
+// The version number of the Schema. Ideally all changes will be backward
+// compatible. If that ever changes, we must ensure that version is the first
+// entry in the new tflite root so that we can see that version is not 1.
+#define TFLITE_SCHEMA_VERSION (3)
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_VERSION_H_
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index 8b77c99cb5..5f06106c1d 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -8,6 +8,7 @@ tensorflow/core/kernels/xent_op.cc
tensorflow/core/kernels/where_op.cc
tensorflow/core/kernels/variable_ops.cc
tensorflow/core/kernels/unpack_op.cc
+tensorflow/core/kernels/unique_op.cc
tensorflow/core/kernels/transpose_op.cc
tensorflow/core/kernels/transpose_functor_cpu.cc
tensorflow/core/kernels/training_op_helpers.cc
@@ -41,6 +42,9 @@ tensorflow/core/kernels/spectrogram_op.cc
tensorflow/core/kernels/spectrogram.cc
tensorflow/core/kernels/sparse_to_dense_op.cc
tensorflow/core/kernels/sparse_matmul_op.cc
+tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
+tensorflow/core/kernels/sparse_reshape_op.c
+tensorflow/core/kernels/segment_reduction_ops.cc
tensorflow/core/kernels/softsign_op.cc
tensorflow/core/kernels/softplus_op.cc
tensorflow/core/kernels/softmax_op.cc
@@ -109,6 +113,10 @@ tensorflow/core/kernels/maxpooling_op.cc
tensorflow/core/kernels/matmul_op.cc
tensorflow/core/kernels/lrn_op.cc
tensorflow/core/kernels/logging_ops.cc
+tensorflow/core/kernels/initializable_lookup_table.c
+tensorflow/core/kernels/lookup_table_init_op.cc
+tensorflow/core/kernels/lookup_table_op.cc
+tensorflow/core/kernels/lookup_util.cc
tensorflow/core/kernels/inplace_ops.cc
tensorflow/core/kernels/in_topk_op.cc
tensorflow/core/kernels/immutable_constant_op.cc
@@ -116,10 +124,18 @@ tensorflow/core/kernels/identity_op.cc
tensorflow/core/kernels/identity_n_op.cc
tensorflow/core/kernels/gather_op.cc
tensorflow/core/kernels/gather_functor.cc
+tensorflow/core/kernels/gather_nd_op.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_1.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_2.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_3.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_4.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_5.cc
tensorflow/core/kernels/fused_batch_norm_op.cc
tensorflow/core/kernels/function_ops.cc
tensorflow/core/kernels/fill_functor.cc
tensorflow/core/kernels/fifo_queue.cc
+tensorflow/core/kernels/fifo_queue_op.cc
tensorflow/core/kernels/fake_quant_ops.cc
tensorflow/core/kernels/example_parsing_ops.cc
tensorflow/core/kernels/encode_wav_op.cc
@@ -166,6 +182,8 @@ tensorflow/core/kernels/cwise_op_floor.cc
tensorflow/core/kernels/cwise_op_exp.cc
tensorflow/core/kernels/cwise_op_equal_to_2.cc
tensorflow/core/kernels/cwise_op_equal_to_1.cc
+tensorflow/core/kernels/cwise_op_not_equal_to_2.cc
+tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
tensorflow/core/kernels/cwise_op_div.cc
tensorflow/core/kernels/cwise_op_bitwise_xor.cc
tensorflow/core/kernels/cwise_op_bitwise_or.cc
diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py
index 302042c4dd..8eed45c4b3 100644
--- a/tensorflow/contrib/metrics/__init__.py
+++ b/tensorflow/contrib/metrics/__init__.py
@@ -27,6 +27,7 @@ See the @{$python/contrib.metrics} guide.
@@streaming_false_negative_rate
@@streaming_false_negative_rate_at_thresholds
@@streaming_auc
+@@streaming_dynamic_auc
@@streaming_curve_points
@@streaming_recall_at_k
@@streaming_mean_absolute_error
@@ -88,6 +89,7 @@ from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_concat
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_covariance
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_curve_points
+from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_dynamic_auc
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negative_rate
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negative_rate_at_thresholds
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 3dd1f1a627..24692ff12f 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -1178,6 +1178,154 @@ def streaming_auc(predictions,
name=name)
+def _compute_dynamic_auc(labels, predictions, curve='ROC'):
+ """Computes the apporixmate AUC by a Riemann sum with data-derived thresholds.
+
+ Computes the area under the ROC or PR curve using each prediction as a
+ threshold. This could be slow for large batches, but has the advantage of not
+ having its results degrade depending on the distribution of predictions.
+
+ Args:
+ labels: A `Tensor` of ground truth labels with the same shape as
+ `predictions` with values of 0 or 1 and type `int64`.
+ predictions: A 1-D `Tensor` of predictions whose values are `float64`.
+ curve: The name of the curve to be computed, 'ROC' for the Receiving
+ Operating Characteristic or 'PR' for the Precision-Recall curve.
+
+ Returns:
+ A scalar `Tensor` containing the area-under-curve value for the input.
+ """
+ # Count the total number of positive and negative labels in the input.
+ size = array_ops.size(predictions)
+ total_positive = math_ops.cast(math_ops.reduce_sum(labels), dtypes.int32)
+
+ def continue_computing_dynamic_auc():
+ """Continues dynamic auc computation, entered if labels are not all equal.
+
+ Returns:
+ A scalar `Tensor` containing the area-under-curve value.
+ """
+ # Sort the predictions descending, and the corresponding labels as well.
+ ordered_predictions, indices = nn.top_k(predictions, k=size)
+ ordered_labels = array_ops.gather(labels, indices)
+
+ # Get the counts of the unique ordered predictions.
+ _, _, counts = array_ops.unique_with_counts(ordered_predictions)
+
+ # Compute the indices of the split points between different predictions.
+ splits = math_ops.cast(
+ array_ops.pad(math_ops.cumsum(counts), paddings=[[1, 0]]), dtypes.int32)
+
+ # Count the positives to the left of the split indices.
+ positives = math_ops.cast(
+ array_ops.pad(math_ops.cumsum(ordered_labels), paddings=[[1, 0]]),
+ dtypes.int32)
+ true_positives = array_ops.gather(positives, splits)
+ if curve == 'ROC':
+ # Count the negatives to the left of every split point and the total
+ # number of negatives for computing the FPR.
+ false_positives = math_ops.subtract(splits, true_positives)
+ total_negative = size - total_positive
+ x_axis_values = math_ops.truediv(false_positives, total_negative)
+ y_axis_values = math_ops.truediv(true_positives, total_positive)
+ elif curve == 'PR':
+ x_axis_values = math_ops.truediv(true_positives, total_positive)
+ # For conformance, set precision to 1 when the number of positive
+ # classifications is 0.
+ y_axis_values = array_ops.where(
+ math_ops.greater(splits, 0),
+ math_ops.truediv(true_positives, splits),
+ array_ops.ones_like(true_positives, dtype=dtypes.float64))
+
+ # Calculate trapezoid areas.
+ heights = math_ops.add(y_axis_values[1:], y_axis_values[:-1]) / 2.0
+ widths = math_ops.abs(
+ math_ops.subtract(x_axis_values[1:], x_axis_values[:-1]))
+ return math_ops.reduce_sum(math_ops.multiply(heights, widths))
+
+ # If all the labels are the same, AUC isn't well-defined (but raising an
+ # exception seems excessive) so we return 0, otherwise we finish computing.
+ return control_flow_ops.cond(
+ math_ops.logical_or(
+ math_ops.equal(total_positive, 0),
+ math_ops.equal(total_positive, size)
+ ),
+ true_fn=lambda: array_ops.constant(0, dtypes.float64),
+ false_fn=continue_computing_dynamic_auc)
+
+
+def streaming_dynamic_auc(labels,
+ predictions,
+ curve='ROC',
+ metrics_collections=(),
+ updates_collections=(),
+ name=None):
+ """Computes the apporixmate AUC by a Riemann sum with data-derived thresholds.
+
+ USAGE NOTE: this approach requires storing all of the predictions and labels
+ for a single evaluation in memory, so it may not be usable when the evaluation
+ batch size and/or the number of evaluation steps is very large.
+
+ Computes the area under the ROC or PR curve using each prediction as a
+ threshold. This has the advantage of being resilient to the distribution of
+ predictions by aggregating across batches, accumulating labels and predictions
+ and performing the final calculation using all of the concatenated values.
+
+ Args:
+ labels: A `Tensor` of ground truth labels with the same shape as `labels`
+ and with values of 0 or 1 whose values are castable to `int64`.
+ predictions: A `Tensor` of predictions whose values are castable to
+ `float64`. Will be flattened into a 1-D `Tensor`.
+ curve: The name of the curve for which to compute AUC, 'ROC' for the
+ Receiving Operating Characteristic or 'PR' for the Precision-Recall curve.
+ metrics_collections: An optional iterable of collections that `auc` should
+ be added to.
+ updates_collections: An optional iterable of collections that `update_op`
+ should be added to.
+ name: An optional name for the variable_scope that contains the metric
+ variables.
+
+ Returns:
+ auc: A scalar `Tensor` containing the current area-under-curve value.
+ update_op: An operation that concatenates the input labels and predictions
+ to the accumulated values.
+
+ Raises:
+ ValueError: If `labels` and `predictions` have mismatched shapes or if
+ `curve` isn't a recognized curve type.
+ """
+
+ if curve not in ['PR', 'ROC']:
+ raise ValueError('curve must be either ROC or PR, %s unknown' % curve)
+
+ with variable_scope.variable_scope(name, default_name='dynamic_auc'):
+ labels.get_shape().assert_is_compatible_with(predictions.get_shape())
+ predictions = array_ops.reshape(
+ math_ops.cast(predictions, dtypes.float64), [-1])
+ labels = array_ops.reshape(math_ops.cast(labels, dtypes.int64), [-1])
+ with ops.control_dependencies([
+ check_ops.assert_greater_equal(
+ labels,
+ array_ops.zeros_like(labels, dtypes.int64),
+ message='labels must be 0 or 1, at least one is <0'),
+ check_ops.assert_less_equal(
+ labels,
+ array_ops.ones_like(labels, dtypes.int64),
+ message='labels must be 0 or 1, at least one is >1')
+ ]):
+ preds_accum, update_preds = streaming_concat(predictions,
+ name='concat_preds')
+ labels_accum, update_labels = streaming_concat(labels,
+ name='concat_labels')
+ update_op = control_flow_ops.group(update_labels, update_preds)
+ auc = _compute_dynamic_auc(labels_accum, preds_accum, curve=curve)
+ if updates_collections:
+ ops.add_to_collections(updates_collections, update_op)
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, auc)
+ return auc, update_op
+
+
def streaming_precision_recall_at_equal_thresholds(predictions,
labels,
num_thresholds=None,
@@ -3285,6 +3433,7 @@ __all__ = [
'streaming_accuracy',
'streaming_auc',
'streaming_curve_points',
+ 'streaming_dynamic_auc',
'streaming_false_negative_rate',
'streaming_false_negative_rate_at_thresholds',
'streaming_false_negatives',
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index 6a8e58b4da..5d0463e1f7 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -1708,6 +1708,34 @@ class StreamingCurvePointsTest(test.TestCase):
[[1.0, 4.0 / 6.0], [0.75, 1.0], [0.0, 1.0]])
+def _np_auc(predictions, labels, weights=None):
+ """Computes the AUC explicitly using Numpy.
+
+ Args:
+ predictions: an ndarray with shape [N].
+ labels: an ndarray with shape [N].
+ weights: an ndarray with shape [N].
+
+ Returns:
+ the area under the ROC curve.
+ """
+ if weights is None:
+ weights = np.ones(np.size(predictions))
+ is_positive = labels > 0
+ num_positives = np.sum(weights[is_positive])
+ num_negatives = np.sum(weights[~is_positive])
+
+ # Sort descending:
+ inds = np.argsort(-predictions)
+
+ sorted_labels = labels[inds]
+ sorted_weights = weights[inds]
+ is_positive = sorted_labels > 0
+
+ tp = np.cumsum(sorted_weights * is_positive) / num_positives
+ return np.sum((sorted_weights * tp)[~is_positive]) / num_negatives
+
+
class StreamingAUCTest(test.TestCase):
def setUp(self):
@@ -1896,33 +1924,6 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(1, auc.eval(), 6)
- def np_auc(self, predictions, labels, weights):
- """Computes the AUC explicitly using Numpy.
-
- Args:
- predictions: an ndarray with shape [N].
- labels: an ndarray with shape [N].
- weights: an ndarray with shape [N].
-
- Returns:
- the area under the ROC curve.
- """
- if weights is None:
- weights = np.ones(np.size(predictions))
- is_positive = labels > 0
- num_positives = np.sum(weights[is_positive])
- num_negatives = np.sum(weights[~is_positive])
-
- # Sort descending:
- inds = np.argsort(-predictions)
-
- sorted_labels = labels[inds]
- sorted_weights = weights[inds]
- is_positive = sorted_labels > 0
-
- tp = np.cumsum(sorted_weights * is_positive) / num_positives
- return np.sum((sorted_weights * tp)[~is_positive]) / num_negatives
-
def testWithMultipleUpdates(self):
num_samples = 1000
batch_size = 10
@@ -1945,7 +1946,7 @@ class StreamingAUCTest(test.TestCase):
for weights in (None, np.ones(num_samples), np.random.exponential(
scale=1.0, size=num_samples)):
- expected_auc = self.np_auc(predictions, labels, weights)
+ expected_auc = _np_auc(predictions, labels, weights)
with self.test_session() as sess:
enqueue_ops = [[] for i in range(num_batches)]
@@ -1974,6 +1975,211 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(expected_auc, auc.eval(), 2)
+class StreamingDynamicAUCTest(test.TestCase):
+
+ def setUp(self):
+ super(StreamingDynamicAUCTest, self).setUp()
+ np.random.seed(1)
+ ops.reset_default_graph()
+
+ def testUnknownCurve(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'curve must be either ROC or PR, TEST_CURVE unknown'):
+ metrics.streaming_dynamic_auc(labels=array_ops.ones((10, 1)),
+ predictions=array_ops.ones((10, 1)),
+ curve='TEST_CURVE')
+
+ def testVars(self):
+ metrics.streaming_dynamic_auc(
+ labels=array_ops.ones((10, 1)), predictions=array_ops.ones((10, 1)))
+ _assert_metric_variables(self, ['dynamic_auc/concat_labels/array:0',
+ 'dynamic_auc/concat_labels/size:0',
+ 'dynamic_auc/concat_preds/array:0',
+ 'dynamic_auc/concat_preds/size:0'])
+
+ def testMetricsCollection(self):
+ my_collection_name = '__metrics__'
+ auc, _ = metrics.streaming_dynamic_auc(
+ labels=array_ops.ones((10, 1)),
+ predictions=array_ops.ones((10, 1)),
+ metrics_collections=[my_collection_name])
+ self.assertEqual(ops.get_collection(my_collection_name), [auc])
+
+ def testUpdatesCollection(self):
+ my_collection_name = '__updates__'
+ _, update_op = metrics.streaming_dynamic_auc(
+ labels=array_ops.ones((10, 1)),
+ predictions=array_ops.ones((10, 1)),
+ updates_collections=[my_collection_name])
+ self.assertEqual(ops.get_collection(my_collection_name), [update_op])
+
+ def testValueTensorIsIdempotent(self):
+ predictions = random_ops.random_uniform(
+ (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
+ labels = random_ops.random_uniform(
+ (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
+ auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ # Run several updates.
+ for _ in xrange(10):
+ sess.run(update_op)
+ # Then verify idempotency.
+ initial_auc = auc.eval()
+ for _ in xrange(10):
+ self.assertAlmostEqual(initial_auc, auc.eval(), 5)
+
+ def testAllLabelsOnes(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant([1., 1., 1.])
+ labels = constant_op.constant([1, 1, 1])
+ auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+ sess.run(variables.local_variables_initializer())
+ sess.run(update_op)
+ self.assertEqual(0, auc.eval())
+
+ def testAllLabelsZeros(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant([1., 1., 1.])
+ labels = constant_op.constant([0, 0, 0])
+ auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+ sess.run(variables.local_variables_initializer())
+ sess.run(update_op)
+ self.assertEqual(0, auc.eval())
+
+ def testNonZeroOnePredictions(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant([2.5, -2.5, 2.5, -2.5],
+ dtype=dtypes_lib.float32)
+ labels = constant_op.constant([1, 0, 1, 0])
+ auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+ sess.run(variables.local_variables_initializer())
+ sess.run(update_op)
+ self.assertAlmostEqual(auc.eval(), 1.0)
+
+ def testAllCorrect(self):
+ inputs = np.random.randint(0, 2, size=(100, 1))
+ with self.test_session() as sess:
+ predictions = constant_op.constant(inputs)
+ labels = constant_op.constant(inputs)
+ auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+ sess.run(variables.local_variables_initializer())
+ sess.run(update_op)
+ self.assertEqual(1, auc.eval())
+
+ def testSomeCorrect(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant([1, 0, 1, 0])
+ labels = constant_op.constant([0, 1, 1, 0])
+ auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+ sess.run(variables.local_variables_initializer())
+ sess.run(update_op)
+ self.assertAlmostEqual(0.5, auc.eval())
+
+ def testAllIncorrect(self):
+ inputs = np.random.randint(0, 2, size=(100, 1))
+ with self.test_session() as sess:
+ predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
+ labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
+ auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+ sess.run(variables.local_variables_initializer())
+ sess.run(update_op)
+ self.assertAlmostEqual(0, auc.eval())
+
+ def testExceptionOnIncompatibleShapes(self):
+ with self.test_session() as sess:
+ predictions = array_ops.ones([5])
+ labels = array_ops.zeros([6])
+ with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'):
+ _, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+ sess.run(variables.local_variables_initializer())
+ sess.run(update_op)
+
+ def testExceptionOnGreaterThanOneLabel(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32)
+ labels = constant_op.constant([2, 1, 0])
+ _, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+ sess.run(variables.local_variables_initializer())
+ with self.assertRaisesRegexp(
+ errors_impl.InvalidArgumentError,
+ '.*labels must be 0 or 1, at least one is >1.*'):
+ sess.run(update_op)
+
+ def testExceptionOnNegativeLabel(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32)
+ labels = constant_op.constant([1, 0, -1])
+ _, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+ sess.run(variables.local_variables_initializer())
+ with self.assertRaisesRegexp(
+ errors_impl.InvalidArgumentError,
+ '.*labels must be 0 or 1, at least one is <0.*'):
+ sess.run(update_op)
+
+ def testWithMultipleUpdates(self):
+ batch_size = 10
+ num_batches = 100
+ labels = np.array([])
+ predictions = np.array([])
+ tf_labels = variables.Variable(array_ops.ones(batch_size, dtypes_lib.int32),
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ dtype=dtypes_lib.int32)
+ tf_predictions = variables.Variable(
+ array_ops.ones(batch_size),
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ dtype=dtypes_lib.float32)
+ auc, update_op = metrics.streaming_dynamic_auc(tf_labels, tf_predictions)
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ for _ in xrange(num_batches):
+ new_labels = np.random.randint(0, 2, size=batch_size)
+ noise = np.random.normal(0.0, scale=0.2, size=batch_size)
+ new_predictions = 0.4 + 0.2 * new_labels + noise
+ labels = np.concatenate([labels, new_labels])
+ predictions = np.concatenate([predictions, new_predictions])
+ sess.run(tf_labels.assign(new_labels))
+ sess.run(tf_predictions.assign(new_predictions))
+ sess.run(update_op)
+ expected_auc = _np_auc(predictions, labels)
+ self.assertAlmostEqual(expected_auc, auc.eval())
+
+ def testAUCPRReverseIncreasingPredictions(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant(
+ [0.1, 0.4, 0.35, 0.8], dtype=dtypes_lib.float32)
+ labels = constant_op.constant([0, 0, 1, 1])
+ auc, update_op = metrics.streaming_dynamic_auc(
+ labels, predictions, curve='PR')
+ sess.run(variables.local_variables_initializer())
+ sess.run(update_op)
+ self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-5)
+
+ def testAUCPRJumbledPredictions(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant(
+ [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81], dtypes_lib.float32)
+ labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1])
+ auc, update_op = metrics.streaming_dynamic_auc(
+ labels, predictions, curve='PR')
+ sess.run(variables.local_variables_initializer())
+ sess.run(update_op)
+ self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-6)
+
+ def testAUCPRPredictionsLessThanHalf(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant(
+ [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
+ shape=(1, 7),
+ dtype=dtypes_lib.float32)
+ labels = constant_op.constant([0, 0, 0, 0, 1, 1, 1], shape=(1, 7))
+ auc, update_op = metrics.streaming_dynamic_auc(
+ labels, predictions, curve='PR')
+ sess.run(variables.local_variables_initializer())
+ sess.run(update_op)
+ self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-5)
+
+
class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
def setUp(self):
diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD
index ed9fb64b95..df9dbb457a 100644
--- a/tensorflow/contrib/nccl/BUILD
+++ b/tensorflow/contrib/nccl/BUILD
@@ -48,8 +48,8 @@ tf_cuda_cc_test(
# Disabled on jenkins until errors finding nvmlShutdown are found.
tags = [
"manual",
+ "multi_gpu",
"no_oss",
- "noguitar", # note: is run manually there
"notap",
],
deps = if_cuda(
@@ -138,8 +138,8 @@ cuda_py_test(
# Disabled on jenkins until errors finding nvmlShutdown are found.
tags = [
"manual",
+ "multi_gpu",
"no_oss",
- "noguitar", # note: is run manually there
"notap",
],
)
diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
index 0b13e3595e..bad0abd44c 100644
--- a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
+++ b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
@@ -72,14 +72,15 @@ class NcclTestCase(test.TestCase):
two.
device_sets: Tuple of virtual devices to run test on.
"""
- if not test.is_gpu_available():
- return # Test requires access to a GPU
-
for dtype in [np.float32, np.int32, np.int64, np.float64]:
# Create session inside outer loop to test use of
# same communicator across multiple sessions.
with self.test_session(use_gpu=True) as sess:
+ # Check GPU availability *after* creating test session, see b/68975239.
+ if not test.is_gpu_available():
+ return # Test requires access to a GPU
+
for devices in device_sets:
shape = (3, 4)
random = (np.random.random_sample(shape) - .5) * 1024
diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD
index 935af80e7a..45a98c7f85 100644
--- a/tensorflow/contrib/quantize/BUILD
+++ b/tensorflow/contrib/quantize/BUILD
@@ -133,7 +133,6 @@ py_library(
deps = [
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
@@ -143,6 +142,23 @@ py_library(
],
)
+py_test(
+ name = "quant_ops_test",
+ size = "small",
+ srcs = ["python/quant_ops_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":quant_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:session",
+ "//tensorflow/python:variables",
+ ],
+)
+
py_library(
name = "quantize",
srcs = ["python/quantize.py"],
diff --git a/tensorflow/contrib/quantize/README.md b/tensorflow/contrib/quantize/README.md
new file mode 100644
index 0000000000..782232e85f
--- /dev/null
+++ b/tensorflow/contrib/quantize/README.md
@@ -0,0 +1,73 @@
+tf.contrib.quantize provides tools for transforming graphs to include ops to
+model quantization of weights, biases and activations during both training and
+inference. This is done using the
+[fake quantization op]
+(https://www.tensorflow.org/versions/r0.12/api_docs/python/array_ops/fake_quantization),
+which is described below:
+
+Recent literature has shown that fixed point networks provide comparable
+performance to floating point networks [1]. This is achieved by modeling the
+quantization operation during training in both the forward and backward passes.
+The fake quantization operator achieves this by modeling the quantizer as a pass
+through estimator [2]. Note that during back propagation, the parameters are
+updated at high precision as this is needed to ensure sufficient precision in
+accumulating tiny adjustments to the parameters. However, for the forward pass,
+the parameters and activations are quantized to the desired lower precision.
+
+![drawing](g3doc/drawings/Fake_Quantization.jpg)
+
+###Forward pass
+
+
+
+
+\begin{equation*}
+f_Q(x) = \Delta\text{ }round\left(\frac{sat\left(x\right)-x_{min}}{\Delta}\right)
+\end{equation*}
+
+
+where
+
+$$
+\begin{equation*}
+sat(x) =
+\left\{
+ \begin{array}{ll}
+ x_{min} & \mbox{if } x \le x_{min} \\
+ x & \mbox{if } x_{min} \leq x \leq x_{max} \\
+ x_{max} & \mbox{if } x_{max} \le x
+ \end{array}
+\right.
+\end{equation*}
+$$
+
+
+where $$\Delta$$ is the Quantizer Step size, given by
+$$\Delta =\frac{x_{max} - x_{min} }{255} $$ and $$x_{min} $$ and $$x_{max}$$ are
+the minimum and maximum values of the variable under consideration. Note that
+the rounding performed is deterministic and corresponds to asymmetric rounding,
+which is supported in almost all hardware platforms.
+
+###Backward pass
+For the backward pass, we model the quantizer as a piecewise linear block, with
+derivatives that are non-zero only in the linear region.
+
+
+
+\begin{equation*}
+\frac{df_Q(x)}{dx}=1, x_{min} \leq x \leq x_{max},\text{ 0 elsewhere }
+\end{equation*}
+
+Therefore, the backward pass through the quantizer reduces to passing through
+the gradients as long as the inputs to the quantizer are in the linear region.
+Otherwise, the gradients are set to zero.
+
+Note that the quantizer is fully specified by the min and max values of the
+variables being quantized.
+
+
+[1] P.Gysel, "HARDWARE-ORIENTED APPROXIMATION OF CONVOLUTIONAL
+NEURAL NETWORKS", https://arxiv.org/pdf/1604.03168.pdf
+
+[2] Y.Bengio, "Estimating or Propagating Gradients Through Stochastic Neurons
+for Conditional Computation", https://arxiv.org/abs/1308.3432
diff --git a/tensorflow/contrib/quantize/g3doc/drawings/Fake_Quantization.jpg b/tensorflow/contrib/quantize/g3doc/drawings/Fake_Quantization.jpg
new file mode 100644
index 0000000000..fdc7ae40ce
--- /dev/null
+++ b/tensorflow/contrib/quantize/g3doc/drawings/Fake_Quantization.jpg
Binary files differ
diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py
index 0a38ef9fcd..f80d427ff0 100644
--- a/tensorflow/contrib/quantize/python/quant_ops.py
+++ b/tensorflow/contrib/quantize/python/quant_ops.py
@@ -22,15 +22,12 @@ from tensorflow.contrib.framework.python.ops import add_arg_scope
from tensorflow.contrib.framework.python.ops import model_variable
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import moving_averages
-EPSILON = 1e-5
-
@add_arg_scope
def FixedQuantize(inputs, init_min=-6.0, init_max=6.0, scope=None):
@@ -133,12 +130,10 @@ def LastValueQuantize(inputs,
batch_min = inputs
else:
batch_min = math_ops.reduce_min(inputs, name='BatchMin')
- batch_min -= EPSILON
- # B-eng requires that 0.0 if always in the [min; max] range.
+ # TFLite requires that 0.0 if always in the [min; max] range.
batch_min = math_ops.minimum(batch_min, 0.0)
- assign_min_op = state_ops.assign(
- min_var, batch_min, name='AssignMinLast').op
- ops.add_to_collection(updates_collection, assign_min_op)
+ assign_min = state_ops.assign(min_var, batch_min, name='AssignMinLast')
+ ops.add_to_collection(updates_collection, assign_min.op)
if per_channel:
if input_dim >= 2:
@@ -148,17 +143,15 @@ def LastValueQuantize(inputs,
batch_max = inputs
else:
batch_max = math_ops.reduce_max(inputs, name='BatchMax')
- batch_max += EPSILON
- # B-eng requires that 0.0 if always in the [min; max] range.
+ # TFLite requires that 0.0 if always in the [min; max] range.
batch_max = math_ops.maximum(batch_max, 0.0)
- assign_max_op = state_ops.assign(
- max_var, batch_max, name='AssignMaxLast').op
- ops.add_to_collection(updates_collection, assign_max_op)
+ assign_max = state_ops.assign(max_var, batch_max, name='AssignMaxLast')
+ ops.add_to_collection(updates_collection, assign_max.op)
return _FakeQuantWithMinMaxVars(
inputs,
- batch_min,
- batch_max,
+ assign_min,
+ assign_max,
per_channel=per_channel,
num_bits=num_bits,
narrow_range=narrow_range)
@@ -251,9 +244,9 @@ def MovingAvgQuantize(inputs,
batch_min = math_ops.reduce_min(inputs, name='BatchMin')
# B-eng requires that 0.0 if always in the [min; max] range.
batch_min = math_ops.minimum(batch_min, 0.0)
- assign_min_op = moving_averages.assign_moving_average(
- min_var, batch_min, ema_decay, name='AssignMinEma').op
- ops.add_to_collection(updates_collection, assign_min_op)
+ assign_min = moving_averages.assign_moving_average(
+ min_var, batch_min, ema_decay, name='AssignMinEma')
+ ops.add_to_collection(updates_collection, assign_min.op)
if per_channel:
if input_dim >= 2:
@@ -265,14 +258,14 @@ def MovingAvgQuantize(inputs,
batch_max = math_ops.reduce_max(inputs, name='BatchMax')
# B-eng requires that 0.0 if always in the [min; max] range.
batch_max = math_ops.maximum(batch_max, 0.0)
- assign_max_op = moving_averages.assign_moving_average(
- max_var, batch_max, ema_decay, name='AssignMaxEma').op
- ops.add_to_collection(updates_collection, assign_max_op)
+ assign_max = moving_averages.assign_moving_average(
+ max_var, batch_max, ema_decay, name='AssignMaxEma')
+ ops.add_to_collection(updates_collection, assign_max.op)
return _FakeQuantWithMinMaxVars(
inputs,
- min_var,
- max_var,
+ assign_min,
+ assign_max,
per_channel=per_channel,
num_bits=num_bits,
narrow_range=narrow_range)
@@ -301,20 +294,10 @@ def _FakeQuantWithMinMaxVars(inputs, min_var, max_var, per_channel, num_bits,
if per_channel:
assert len(min_var.get_shape()) == 1
assert len(max_var.get_shape()) == 1
- with ops.control_dependencies([check_ops.assert_less(min_var, max_var)]):
- return array_ops.fake_quant_with_min_max_vars_per_channel(
- inputs,
- min_var,
- max_var,
- num_bits=num_bits,
- narrow_range=narrow_range)
+ return array_ops.fake_quant_with_min_max_vars_per_channel(
+ inputs, min_var, max_var, num_bits=num_bits, narrow_range=narrow_range)
else:
assert min_var.get_shape() == [] # pylint: disable=g-explicit-bool-comparison
assert max_var.get_shape() == [] # pylint: disable=g-explicit-bool-comparison
- with ops.control_dependencies([check_ops.assert_less(min_var, max_var)]):
- return array_ops.fake_quant_with_min_max_vars(
- inputs,
- min_var,
- max_var,
- num_bits=num_bits,
- narrow_range=narrow_range)
+ return array_ops.fake_quant_with_min_max_vars(
+ inputs, min_var, max_var, num_bits=num_bits, narrow_range=narrow_range)
diff --git a/tensorflow/contrib/quantize/python/quant_ops_test.py b/tensorflow/contrib/quantize/python/quant_ops_test.py
new file mode 100644
index 0000000000..3884679602
--- /dev/null
+++ b/tensorflow/contrib/quantize/python/quant_ops_test.py
@@ -0,0 +1,87 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for third_party.tensorflow.contrib.quantize.python.quant_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.quantize.python import quant_ops
+from tensorflow.python.client import session
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+
+_MIN_MAX_VARS = 'min_max_vars'
+
+
+class QuantOpsTest(googletest.TestCase):
+
+ def testLastValueQuantizeTrainingAssign(self):
+ g = ops.Graph()
+ with session.Session(graph=g) as sess:
+ x = array_ops.placeholder(dtypes.float32, shape=[2])
+ y = quant_ops.LastValueQuantize(
+ x,
+ init_min=0.0,
+ init_max=0.0,
+ is_training=True,
+ vars_collection=_MIN_MAX_VARS)
+
+ # Run the step.
+ sess.run(variables.global_variables_initializer())
+ sess.run(y, feed_dict={x: [-1.0, 1.0]})
+ # Now check that the min_max_vars were, in fact, updated.
+ min_value, max_value = self._GetMinMaxValues(sess)
+ self.assertEqual(min_value, -1.0)
+ self.assertEqual(max_value, 1.0)
+
+ def testMovingAvgQuantizeTrainingAssign(self):
+ g = ops.Graph()
+ with session.Session(graph=g) as sess:
+ x = array_ops.placeholder(dtypes.float32, shape=[2])
+ y = quant_ops.MovingAvgQuantize(
+ x,
+ init_min=0.0,
+ init_max=0.0,
+ is_training=True,
+ vars_collection=_MIN_MAX_VARS)
+
+ # Run the step.
+ sess.run(variables.global_variables_initializer())
+ # Do two runs to avoid zero debias.
+ sess.run(y, feed_dict={x: [-1.0, 1.0]})
+ sess.run(y, feed_dict={x: [0.0, 0.0]})
+ # Now check that the min_max_vars were, in fact, updated.
+ min_value, max_value = self._GetMinMaxValues(sess)
+ self.assertGreater(min_value, -1.0)
+ self.assertLess(min_value, 0.0)
+ self.assertGreater(max_value, 0.0)
+ self.assertLess(max_value, 1.0)
+
+ def _GetMinMaxValues(self, sess):
+ min_max_vars = ops.get_collection(_MIN_MAX_VARS)
+ self.assertEqual(len(min_max_vars), 2)
+ min_idx = 0 if 'min' in min_max_vars[0].name else 1
+ max_idx = (min_idx + 1) % 2
+ min_var, max_var = min_max_vars[min_idx], min_max_vars[max_idx]
+ min_max_values = sess.run([min_var, max_var])
+ return min_max_values[0], min_max_values[1]
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index 548e33663e..7db2d863aa 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -89,8 +89,8 @@ def Quantize(graph,
op.name[:-len('/depthwise')])
if separable_conv and separable_conv.type == 'Conv2D':
continue
- if op.type == 'Conv2D':
- # Quantize add ops that come after Conv2D
+ # Quantize add ops that come after Conv2D or DepthwiseConv2dNative.
+ if op.type in ['Conv2D', 'DepthwiseConv2dNative']:
add_context_re = re.search(r'^(.*)/[^/]+/', op.name)
if add_context_re is not None:
context.add_contexts.add(add_context_re.group(1))
@@ -387,7 +387,7 @@ class _QuantizeContext(object):
if delay_requested and self.quant_delay and self.quant_delay > 0:
activate_quant = math_ops.greater_equal(
- training_util.get_global_step(),
+ training_util.get_or_create_global_step(),
self.quant_delay,
name=scope + '/activate_quant')
quant = control_flow_ops.cond(
diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
index 3e62f95bd6..57dab03f16 100644
--- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
@@ -97,8 +97,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantization_node_name)
self.assertEqual(weights_quant.type, quantization_node_name)
expected_inputs = [
- scope + '/weights_quant/Minimum', scope + '/weights_quant/Maximum',
- scope + '/weights/read'
+ scope + '/weights_quant/AssignMinLast',
+ scope + '/weights_quant/AssignMaxLast', scope + '/weights/read'
]
self._AssertInputOpsAre(weights_quant, expected_inputs)
output_op_name = scope + '/Conv2D'
@@ -109,8 +109,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantization_node_name)
self.assertEqual(conv_quant.type, quantization_node_name)
expected_inputs = [
- scope + '/conv_quant/min/read', scope + '/conv_quant/max/read',
- scope + '/BiasAdd'
+ scope + '/conv_quant/AssignMinEma',
+ scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd'
]
self._AssertInputOpsAre(conv_quant, expected_inputs)
output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
@@ -122,7 +122,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self.assertEqual(act_quant.type, quantization_node_name)
expected_inputs = [
- 'test/act_quant/min/read', 'test/act_quant/max/read',
+ 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
'test/' + activation_op_name
]
self._AssertInputOpsAre(act_quant, expected_inputs)
@@ -172,8 +172,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantization_node_name)
self.assertEqual(weights_quant.type, quantization_node_name)
expected_inputs = [
- scope + '/weights_quant/Minimum', scope + '/weights_quant/Maximum',
- scope + '/weights/read'
+ scope + '/weights_quant/AssignMinLast',
+ scope + '/weights_quant/AssignMaxLast', scope + '/weights/read'
]
self._AssertInputOpsAre(weights_quant, expected_inputs)
output_op_name = scope + '/MatMul'
@@ -184,8 +184,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantization_node_name)
self.assertEqual(conv_quant.type, quantization_node_name)
expected_inputs = [
- scope + '/conv_quant/min/read', scope + '/conv_quant/max/read',
- scope + '/BiasAdd'
+ scope + '/conv_quant/AssignMinEma',
+ scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd'
]
self._AssertInputOpsAre(conv_quant, expected_inputs)
output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
@@ -196,7 +196,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantization_node_name)
self.assertEqual(act_quant.type, quantization_node_name)
expected_inputs = [
- 'test/act_quant/min/read', 'test/act_quant/max/read',
+ 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
'test/' + activation_op_name
]
self._AssertInputOpsAre(act_quant, expected_inputs)
@@ -247,7 +247,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantization_node_name)
self.assertEqual(weights_quant.type, quantization_node_name)
expected_inputs = [
- scope + '/weights_quant/Minimum', scope + '/weights_quant/Maximum',
+ scope + '/weights_quant/AssignMinLast',
+ scope + '/weights_quant/AssignMaxLast',
scope + '/depthwise_weights/read'
]
self._AssertInputOpsAre(weights_quant, expected_inputs)
@@ -259,8 +260,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantization_node_name)
self.assertEqual(conv_quant.type, quantization_node_name)
expected_inputs = [
- scope + '/conv_quant/min/read', scope + '/conv_quant/max/read',
- scope + '/BiasAdd'
+ scope + '/conv_quant/AssignMinEma',
+ scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd'
]
self._AssertInputOpsAre(conv_quant, expected_inputs)
output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
@@ -271,7 +272,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantization_node_name)
self.assertEqual(act_quant.type, quantization_node_name)
expected_inputs = [
- 'test/act_quant/min/read', 'test/act_quant/max/read',
+ 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
'test/' + activation_op_name
]
self._AssertInputOpsAre(act_quant, expected_inputs)
@@ -401,8 +402,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantization_node_name)
self.assertEqual(weights_quant.type, quantization_node_name)
expected_inputs = [
- scope + '/weights_quant/' + ('min/read' if use_ema else 'Minimum'),
- scope + '/weights_quant/' + ('max/read' if use_ema else 'Maximum'),
+ scope + '/weights_quant/' + ('AssignMinEma'
+ if use_ema else 'AssignMinLast'),
+ scope + '/weights_quant/' + ('AssignMaxEma'
+ if use_ema else 'AssignMaxLast'),
scope + '/mul_fold'
]
self._AssertInputOpsAre(weights_quant, expected_inputs)
@@ -415,8 +418,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantization_node_name)
self.assertEqual(conv_quant.type, quantization_node_name)
expected_inputs = [
- scope + '/conv_quant/min/read', scope + '/conv_quant/max/read',
- scope + '/add_fold'
+ scope + '/conv_quant/AssignMinEma',
+ scope + '/conv_quant/AssignMaxEma', scope + '/add_fold'
]
self._AssertInputOpsAre(conv_quant, expected_inputs)
output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
@@ -427,7 +430,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantization_node_name)
self.assertEqual(act_quant.type, quantization_node_name)
expected_inputs = [
- 'test/act_quant/min/read', 'test/act_quant/max/read',
+ 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
'test/' + activation_op_name
]
self._AssertInputOpsAre(act_quant, expected_inputs)
@@ -518,8 +521,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantization_node_name)
self.assertEqual(weights_quant.type, quantization_node_name)
expected_inputs = [
- scope + '/weights_quant/' + ('min/read' if use_ema else 'Minimum'),
- scope + '/weights_quant/' + ('max/read' if use_ema else 'Maximum'),
+ scope + '/weights_quant/' + ('AssignMinEma'
+ if use_ema else 'AssignMinLast'),
+ scope + '/weights_quant/' + ('AssignMaxEma'
+ if use_ema else 'AssignMaxLast'),
scope + '/mul_fold'
]
self._AssertInputOpsAre(weights_quant, expected_inputs)
@@ -532,8 +537,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantization_node_name)
self.assertEqual(conv_quant.type, quantization_node_name)
expected_inputs = [
- scope + '/conv_quant/min/read', scope + '/conv_quant/max/read',
- scope + '/add_fold'
+ scope + '/conv_quant/AssignMinEma',
+ scope + '/conv_quant/AssignMaxEma', scope + '/add_fold'
]
self._AssertInputOpsAre(conv_quant, expected_inputs)
output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
@@ -544,7 +549,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantization_node_name)
self.assertEqual(act_quant.type, quantization_node_name)
expected_inputs = [
- 'test/act_quant/min/read', 'test/act_quant/max/read',
+ 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
'test/' + activation_op_name
]
self._AssertInputOpsAre(act_quant, expected_inputs)
@@ -639,8 +644,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantization_node_name)
self.assertEqual(weights_quant.type, quantization_node_name)
expected_inputs = [
- scope + '/weights_quant/' + ('min/read' if use_ema else 'Minimum'),
- scope + '/weights_quant/' + ('max/read' if use_ema else 'Maximum'),
+ scope + '/weights_quant/' + ('AssignMinEma'
+ if use_ema else 'AssignMinLast'),
+ scope + '/weights_quant/' + ('AssignMaxEma'
+ if use_ema else 'AssignMaxLast'),
scope + '/mul_fold'
]
self._AssertInputOpsAre(weights_quant, expected_inputs)
@@ -653,8 +660,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantization_node_name)
self.assertEqual(conv_quant.type, quantization_node_name)
expected_inputs = [
- scope + '/conv_quant/min/read', scope + '/conv_quant/max/read',
- scope + '/add_fold'
+ scope + '/conv_quant/AssignMinEma',
+ scope + '/conv_quant/AssignMaxEma', scope + '/add_fold'
]
self._AssertInputOpsAre(conv_quant, expected_inputs)
output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
@@ -665,7 +672,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantization_node_name)
self.assertEqual(act_quant.type, quantization_node_name)
expected_inputs = [
- 'test/act_quant/min/read', 'test/act_quant/max/read',
+ 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
'test/' + activation_op_name
]
self._AssertInputOpsAre(act_quant, expected_inputs)
diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py
index eb141a21bd..1e4dd7cf67 100644
--- a/tensorflow/contrib/quantize/python/quantize_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest
conv2d = layers.conv2d
+separable_conv2d = layers.separable_conv2d
class QuantizeTest(test_util.TensorFlowTestCase):
@@ -77,6 +78,30 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantization_node_name)
self.assertEqual(add_quant.type, quantization_node_name)
+ def testInsertQuantOpForAddAfterSeparableConv2d(self):
+ graph = ops.Graph()
+ with graph.as_default():
+ batch_size, height, width, depth = 5, 128, 128, 3
+ input1 = array_ops.zeros((batch_size, height, width, depth))
+ input2 = array_ops.zeros((batch_size, height / 2, width / 2, depth))
+ conv = separable_conv2d(input1, None, [5, 5], stride=2,
+ depth_multiplier=1.0, padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=None, scope='test/test')
+ node = math_ops.add(conv, input2, name='test/add')
+ node = array_ops.identity(node, name='test/identity')
+ update_barrier = control_flow_ops.no_op(name='update_barrier')
+ with ops.control_dependencies([update_barrier]):
+ array_ops.identity(node, name='control_dependency')
+
+ quantize.Quantize(graph=graph, weight_bits=8, weight_narrow_range=True,
+ activation_bits=8)
+
+ quantization_node_name = 'FakeQuantWithMinMaxVars'
+ add_quant = graph.get_operation_by_name('test/add_quant/' +
+ quantization_node_name)
+ self.assertEqual(add_quant.type, quantization_node_name)
+
def _WeightInit(self, stddev):
"""Returns truncated normal variable initializer.
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index 91493302b1..01a5540121 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variables
from tensorflow.python.ops import variable_scope as vs
@@ -589,6 +590,24 @@ class AttentionWrapperTest(test.TestCase):
expected_final_alignment_history=expected_final_alignment_history,
name='testBahdanauMonotonicNormalized')
+ def testBahdanauMonotonicHard(self):
+ # Run attention mechanism with mode='hard', make sure probabilities are hard
+ b, t, u, d = 10, 20, 30, 40
+ with self.test_session(use_gpu=True) as sess:
+ a = wrapper.BahdanauMonotonicAttention(
+ d,
+ random_ops.random_normal((b, t, u)),
+ mode='hard')
+ # Just feed previous attention as [1, 0, 0, ...]
+ attn = a(random_ops.random_normal((b, d)), array_ops.one_hot([0]*b, t))
+ sess.run(variables.global_variables_initializer())
+ attn_out = attn.eval()
+ # All values should be 0 or 1
+ self.assertTrue(np.all(np.logical_or(attn_out == 0, attn_out == 1)))
+ # Sum of distributions should be 0 or 1 (0 when all p_choose_i are 0)
+ self.assertTrue(np.all(np.logical_or(attn_out.sum(axis=1) == 1,
+ attn_out.sum(axis=1) == 0)))
+
def testLuongMonotonicNotNormalized(self):
create_attention_mechanism = functools.partial(
wrapper.LuongMonotonicAttention, sigmoid_noise=1.0,
@@ -695,6 +714,24 @@ class AttentionWrapperTest(test.TestCase):
expected_final_alignment_history=expected_final_alignment_history,
name='testMultiAttention')
+ def testLuongMonotonicHard(self):
+ # Run attention mechanism with mode='hard', make sure probabilities are hard
+ b, t, u, d = 10, 20, 30, 40
+ with self.test_session(use_gpu=True) as sess:
+ a = wrapper.LuongMonotonicAttention(
+ d,
+ random_ops.random_normal((b, t, u)),
+ mode='hard')
+ # Just feed previous attention as [1, 0, 0, ...]
+ attn = a(random_ops.random_normal((b, d)), array_ops.one_hot([0]*b, t))
+ sess.run(variables.global_variables_initializer())
+ attn_out = attn.eval()
+ # All values should be 0 or 1
+ self.assertTrue(np.all(np.logical_or(attn_out == 0, attn_out == 1)))
+ # Sum of distributions should be 0 or 1 (0 when all p_choose_i are 0)
+ self.assertTrue(np.all(np.logical_or(attn_out.sum(axis=1) == 1,
+ attn_out.sum(axis=1) == 0)))
+
def testMultiAttentionNoAttentionLayer(self):
create_attention_mechanisms = (
wrapper.BahdanauAttention, wrapper.LuongAttention)
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 0c64c9caf1..c3b180d9f4 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -691,7 +691,11 @@ def _monotonic_probability_fn(score, previous_alignments, sigmoid_noise, mode,
seed=seed)
score += sigmoid_noise*noise
# Compute "choosing" probabilities from the attention scores
- p_choose_i = math_ops.sigmoid(score)
+ if mode == "hard":
+ # When mode is hard, use a hard sigmoid
+ p_choose_i = math_ops.cast(score > 0, score.dtype)
+ else:
+ p_choose_i = math_ops.sigmoid(score)
# Convert from choosing probabilities to attention distribution
return monotonic_attention(p_choose_i, previous_alignments, mode)
diff --git a/tensorflow/contrib/slim/BUILD b/tensorflow/contrib/slim/BUILD
index 23c23af2f4..c2f106c2b2 100644
--- a/tensorflow/contrib/slim/BUILD
+++ b/tensorflow/contrib/slim/BUILD
@@ -39,6 +39,8 @@ py_test(
"//tensorflow/python:summary",
"//tensorflow/python:training",
"//tensorflow/python:variables",
+ "//tensorflow/python/debug:debug_data",
+ "//tensorflow/python/debug:hooks",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/slim/python/slim/evaluation.py b/tensorflow/contrib/slim/python/slim/evaluation.py
index 2d4b08df61..cdb720b36b 100644
--- a/tensorflow/contrib/slim/python/slim/evaluation.py
+++ b/tensorflow/contrib/slim/python/slim/evaluation.py
@@ -153,7 +153,8 @@ def evaluate_once(master,
summary_op=_USE_DEFAULT,
summary_op_feed_dict=None,
variables_to_restore=None,
- session_config=None):
+ session_config=None,
+ hooks=None):
"""Evaluates the model at the given checkpoint path.
Args:
@@ -177,6 +178,8 @@ def evaluate_once(master,
slim.variables.GetVariablesToRestore() is used.
session_config: An instance of `tf.ConfigProto` that will be used to
configure the `Session`. If left as `None`, the default will be used.
+ hooks: A list of additional `SessionRunHook` objects to pass during the
+ evaluation.
Returns:
The value of `final_op` or `None` if `final_op` is `None`.
@@ -184,11 +187,13 @@ def evaluate_once(master,
if summary_op == _USE_DEFAULT:
summary_op = summary.merge_all()
- hooks = [evaluation.StopAfterNEvalsHook(num_evals),]
+ all_hooks = [evaluation.StopAfterNEvalsHook(num_evals),]
if summary_op is not None:
- hooks.append(evaluation.SummaryAtEndHook(
+ all_hooks.append(evaluation.SummaryAtEndHook(
log_dir=logdir, summary_op=summary_op, feed_dict=summary_op_feed_dict))
+ if hooks is not None:
+ all_hooks.extend(hooks)
saver = None
if variables_to_restore is not None:
@@ -203,7 +208,7 @@ def evaluate_once(master,
feed_dict=eval_op_feed_dict,
final_ops=final_op,
final_ops_feed_dict=final_op_feed_dict,
- hooks=hooks,
+ hooks=all_hooks,
config=session_config)
@@ -256,7 +261,7 @@ def evaluation_loop(master,
configure the `Session`. If left as `None`, the default will be used.
timeout: The maximum amount of time to wait between checkpoints. If left as
`None`, then the process will wait indefinitely.
- hooks: A list of additional SessionRunHook objects to pass during
+ hooks: A list of additional `SessionRunHook` objects to pass during
repeated evaluations.
Returns:
diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py
index d9e0f54b72..870f504d10 100644
--- a/tensorflow/contrib/slim/python/slim/evaluation_test.py
+++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import glob
import os
+import shutil
import time
import numpy as np
@@ -29,6 +30,8 @@ from tensorflow.contrib.metrics.python.ops import metric_ops
from tensorflow.contrib.slim.python.slim import evaluation
from tensorflow.contrib.training.python.training import evaluation as evaluation_lib
from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python.debug.lib import debug_data
+from tensorflow.python.debug.wrappers import hooks
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -230,11 +233,7 @@ class SingleEvaluationTest(test.TestCase):
with self.assertRaises(errors.NotFoundError):
evaluation.evaluate_once('', checkpoint_path, log_dir)
- def testRestoredModelPerformance(self):
- checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
- log_dir = os.path.join(self.get_temp_dir(), 'log_dir1/')
-
- # First, save out the current model to a checkpoint:
+ def _prepareCheckpoint(self, checkpoint_path):
init_op = control_flow_ops.group(variables.global_variables_initializer(),
variables.local_variables_initializer())
saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1)
@@ -242,6 +241,13 @@ class SingleEvaluationTest(test.TestCase):
sess.run(init_op)
saver.save(sess, checkpoint_path)
+ def testRestoredModelPerformance(self):
+ checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
+ log_dir = os.path.join(self.get_temp_dir(), 'log_dir1/')
+
+ # First, save out the current model to a checkpoint:
+ self._prepareCheckpoint(checkpoint_path)
+
# Next, determine the metric to evaluate:
value_op, update_op = metric_ops.streaming_accuracy(self._predictions,
self._labels)
@@ -251,6 +257,36 @@ class SingleEvaluationTest(test.TestCase):
'', checkpoint_path, log_dir, eval_op=update_op, final_op=value_op)
self.assertAlmostEqual(accuracy_value, self._expected_accuracy)
+ def testAdditionalHooks(self):
+ checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
+ log_dir = os.path.join(self.get_temp_dir(), 'log_dir1/')
+
+ # First, save out the current model to a checkpoint:
+ self._prepareCheckpoint(checkpoint_path)
+
+ # Next, determine the metric to evaluate:
+ value_op, update_op = metric_ops.streaming_accuracy(self._predictions,
+ self._labels)
+
+ dumping_root = os.path.join(self.get_temp_dir(), 'tfdbg_dump_dir')
+ dumping_hook = hooks.DumpingDebugHook(dumping_root, log_usage=False)
+ try:
+ # Run the evaluation and verify the results:
+ accuracy_value = evaluation.evaluate_once(
+ '', checkpoint_path, log_dir, eval_op=update_op, final_op=value_op,
+ hooks=[dumping_hook])
+ self.assertAlmostEqual(accuracy_value, self._expected_accuracy)
+
+ dump = debug_data.DebugDumpDir(
+ glob.glob(os.path.join(dumping_root, 'run_*'))[0])
+ # Here we simply assert that the dumped data has been loaded and is
+ # non-empty. We do not care about the detailed model-internal tensors or
+ # their values.
+ self.assertTrue(dump.dumped_tensor_data)
+ finally:
+ if os.path.isdir(dumping_root):
+ shutil.rmtree(dumping_root)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD
index da23f1c380..3c60d2bb56 100644
--- a/tensorflow/contrib/summary/BUILD
+++ b/tensorflow/contrib/summary/BUILD
@@ -26,12 +26,18 @@ py_test(
deps = [
":summary_ops",
":summary_test_util",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:errors",
+ "//tensorflow/python:framework",
"//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:ops",
"//tensorflow/python:platform",
+ "//tensorflow/python:state_ops",
"//tensorflow/python:training",
"//tensorflow/python/eager:function",
"//tensorflow/python/eager:test",
+ "@six_archive//:six",
],
)
diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py
index ca82ea094c..813e8b2b09 100644
--- a/tensorflow/contrib/summary/summary.py
+++ b/tensorflow/contrib/summary/summary.py
@@ -28,11 +28,13 @@ from __future__ import print_function
from tensorflow.contrib.summary.summary_ops import all_summary_ops
from tensorflow.contrib.summary.summary_ops import always_record_summaries
from tensorflow.contrib.summary.summary_ops import audio
+from tensorflow.contrib.summary.summary_ops import create_summary_db_writer
from tensorflow.contrib.summary.summary_ops import create_summary_file_writer
from tensorflow.contrib.summary.summary_ops import eval_dir
from tensorflow.contrib.summary.summary_ops import generic
from tensorflow.contrib.summary.summary_ops import histogram
from tensorflow.contrib.summary.summary_ops import image
+from tensorflow.contrib.summary.summary_ops import import_event
from tensorflow.contrib.summary.summary_ops import never_record_summaries
from tensorflow.contrib.summary.summary_ops import record_summaries_every_n_global_steps
from tensorflow.contrib.summary.summary_ops import scalar
diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py
index 56e3198593..f6be99f6ae 100644
--- a/tensorflow/contrib/summary/summary_ops.py
+++ b/tensorflow/contrib/summary/summary_ops.py
@@ -19,7 +19,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import getpass
import os
+import re
+import time
+
+import six
from tensorflow.contrib.summary import gen_summary_ops
from tensorflow.python.eager import context
@@ -42,6 +47,10 @@ _SHOULD_RECORD_SUMMARIES_NAME = "ShouldRecordSummaries"
_SUMMARY_COLLECTION_NAME = "_SUMMARY_V2"
_SUMMARY_WRITER_INIT_COLLECTION_NAME = "_SUMMARY_WRITER_V2"
+_EXPERIMENT_NAME_PATTERNS = re.compile(r"^[^\x00-\x1F<>]{0,256}$")
+_RUN_NAME_PATTERNS = re.compile(r"^[^\x00-\x1F<>]{0,512}$")
+_USER_NAME_PATTERNS = re.compile(r"^[a-z]([-a-z0-9]{0,29}[a-z0-9])?$", re.I)
+
def should_record_summaries():
"""Returns boolean Tensor which is true if summaries should be recorded."""
@@ -57,12 +66,14 @@ def should_record_summaries():
# TODO(apassos) consider how to handle local step here.
@tf_contextlib.contextmanager
-def record_summaries_every_n_global_steps(n):
+def record_summaries_every_n_global_steps(n, global_step=None):
"""Sets the should_record_summaries Tensor to true if global_step % n == 0."""
+ if global_step is None:
+ global_step = training_util.get_global_step()
collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
old = collection_ref[:]
with ops.device("cpu:0"):
- collection_ref[:] = [math_ops.equal(training_util.get_global_step() % n, 0)]
+ collection_ref[:] = [math_ops.equal(global_step % n, 0)]
yield
collection_ref[:] = old
@@ -130,7 +141,8 @@ def create_summary_file_writer(logdir,
flush once the queue gets bigger than this.
flush_millis: the largest interval between flushes.
filename_suffix: optional suffix for the event file name.
- name: name for the summary writer.
+ name: Shared name for this SummaryWriter resource stored to default
+ Graph.
Returns:
Either a summary writer or an empty object which can be used as a
@@ -145,14 +157,81 @@ def create_summary_file_writer(logdir,
flush_millis = constant_op.constant(2 * 60 * 1000)
if filename_suffix is None:
filename_suffix = constant_op.constant("")
- resource = gen_summary_ops.summary_writer(shared_name=name)
- # TODO(apassos) ensure the initialization op runs when in graph mode;
- # consider calling session.run here.
- ops.add_to_collection(
- _SUMMARY_WRITER_INIT_COLLECTION_NAME,
- gen_summary_ops.create_summary_file_writer(
- resource, logdir, max_queue, flush_millis, filename_suffix))
- return SummaryWriter(resource)
+ return _make_summary_writer(
+ name,
+ gen_summary_ops.create_summary_file_writer,
+ logdir=logdir,
+ max_queue=max_queue,
+ flush_millis=flush_millis,
+ filename_suffix=filename_suffix)
+
+
+def create_summary_db_writer(db_uri,
+ experiment_name=None,
+ run_name=None,
+ user_name=None,
+ name=None):
+ """Creates a summary database writer in the current context.
+
+ This can be used to write tensors from the execution graph directly
+ to a database. Only SQLite is supported right now. This function
+ will create the schema if it doesn't exist. Entries in the Users,
+ Experiments, and Runs tables will be created automatically if they
+ don't already exist.
+
+ Args:
+ db_uri: For example "file:/tmp/foo.sqlite".
+ experiment_name: Defaults to YYYY-MM-DD in local time if None.
+ Empty string means the Run will not be associated with an
+ Experiment. Can't contain ASCII control characters or <>. Case
+ sensitive.
+ run_name: Defaults to HH:MM:SS in local time if None. Empty string
+ means a Tag will not be associated with any Run. Can't contain
+ ASCII control characters or <>. Case sensitive.
+ user_name: Defaults to system username if None. Empty means the
+ Experiment will not be associated with a User. Must be valid as
+ both a DNS label and Linux username.
+ name: Shared name for this SummaryWriter resource stored to default
+ Graph.
+
+ Returns:
+ A new SummaryWriter instance.
+ """
+ with ops.device("cpu:0"):
+ if experiment_name is None:
+ experiment_name = time.strftime("%Y-%m-%d", time.localtime(time.time()))
+ if run_name is None:
+ run_name = time.strftime("%H:%M:%S", time.localtime(time.time()))
+ if user_name is None:
+ user_name = getpass.getuser()
+ experiment_name = _cleanse_string(
+ "experiment_name", _EXPERIMENT_NAME_PATTERNS, experiment_name)
+ run_name = _cleanse_string("run_name", _RUN_NAME_PATTERNS, run_name)
+ user_name = _cleanse_string("user_name", _USER_NAME_PATTERNS, user_name)
+ return _make_summary_writer(
+ name,
+ gen_summary_ops.create_summary_db_writer,
+ db_uri=db_uri,
+ experiment_name=experiment_name,
+ run_name=run_name,
+ user_name=user_name)
+
+
+def _make_summary_writer(name, factory, **kwargs):
+ resource = gen_summary_ops.summary_writer(shared_name=name)
+ # TODO(apassos): Consider doing this instead.
+ # node = factory(resource, **kwargs)
+ # if not context.in_eager_mode():
+ # ops.get_default_session().run(node)
+ ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME,
+ factory(resource, **kwargs))
+ return SummaryWriter(resource)
+
+
+def _cleanse_string(name, pattern, value):
+ if isinstance(value, six.string_types) and pattern.search(value) is None:
+ raise ValueError("%s (%s) must match %s" % (name, value, pattern.pattern))
+ return ops.convert_to_tensor(value, dtypes.string)
def _nothing():
@@ -204,68 +283,81 @@ def summary_writer_function(name, tensor, function, family=None):
return op
-def generic(name, tensor, metadata, family=None):
+def generic(name, tensor, metadata=None, family=None, global_step=None):
"""Writes a tensor summary if possible."""
-
+ if global_step is None:
+ global_step = training_util.get_global_step()
def function(tag, scope):
+ if metadata is None:
+ serialized_metadata = constant_op.constant("")
+ elif hasattr(metadata, "SerializeToString"):
+ serialized_metadata = constant_op.constant(metadata.SerializeToString())
+ else:
+ serialized_metadata = metadata
# Note the identity to move the tensor to the CPU.
return gen_summary_ops.write_summary(
context.context().summary_writer_resource,
- training_util.get_global_step(), array_ops.identity(tensor),
- tag, metadata, name=scope)
+ global_step, array_ops.identity(tensor),
+ tag, serialized_metadata, name=scope)
return summary_writer_function(name, tensor, function, family=family)
-def scalar(name, tensor, family=None):
+def scalar(name, tensor, family=None, global_step=None):
"""Writes a scalar summary if possible."""
-
+ if global_step is None:
+ global_step = training_util.get_global_step()
def function(tag, scope):
# Note the identity to move the tensor to the CPU.
return gen_summary_ops.write_scalar_summary(
context.context().summary_writer_resource,
- training_util.get_global_step(), tag, array_ops.identity(tensor),
+ global_step, tag, array_ops.identity(tensor),
name=scope)
return summary_writer_function(name, tensor, function, family=family)
-def histogram(name, tensor, family=None):
+def histogram(name, tensor, family=None, global_step=None):
"""Writes a histogram summary if possible."""
-
+ if global_step is None:
+ global_step = training_util.get_global_step()
def function(tag, scope):
# Note the identity to move the tensor to the CPU.
return gen_summary_ops.write_histogram_summary(
context.context().summary_writer_resource,
- training_util.get_global_step(), tag, array_ops.identity(tensor),
+ global_step, tag, array_ops.identity(tensor),
name=scope)
return summary_writer_function(name, tensor, function, family=family)
-def image(name, tensor, bad_color=None, max_images=3, family=None):
+def image(name, tensor, bad_color=None, max_images=3, family=None,
+ global_step=None):
"""Writes an image summary if possible."""
-
+ if global_step is None:
+ global_step = training_util.get_global_step()
def function(tag, scope):
bad_color_ = (constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8)
if bad_color is None else bad_color)
# Note the identity to move the tensor to the CPU.
return gen_summary_ops.write_image_summary(
context.context().summary_writer_resource,
- training_util.get_global_step(), tag, array_ops.identity(tensor),
+ global_step, tag, array_ops.identity(tensor),
bad_color_,
max_images, name=scope)
return summary_writer_function(name, tensor, function, family=family)
-def audio(name, tensor, sample_rate, max_outputs, family=None):
+def audio(name, tensor, sample_rate, max_outputs, family=None,
+ global_step=None):
"""Writes an audio summary if possible."""
-
+ if global_step is None:
+ global_step = training_util.get_global_step()
def function(tag, scope):
# Note the identity to move the tensor to the CPU.
return gen_summary_ops.write_audio_summary(
context.context().summary_writer_resource,
- training_util.get_global_step(),
+ global_step,
tag,
array_ops.identity(tensor),
sample_rate=sample_rate,
@@ -275,6 +367,26 @@ def audio(name, tensor, sample_rate, max_outputs, family=None):
return summary_writer_function(name, tensor, function, family=family)
+def import_event(tensor, name=None):
+ """Writes a tf.Event binary proto.
+
+ When using create_summary_db_writer(), this can be used alongside
+ tf.TFRecordReader to load event logs into the database. Please note
+ that this is lower level than the other summary functions and will
+ ignore any conditions set by methods like should_record_summaries().
+
+ Args:
+ tensor: A `Tensor` of type `string` containing a serialized `Event`
+ proto.
+ name: A name for the operation (optional).
+
+ Returns:
+ The created Operation.
+ """
+ return gen_summary_ops.import_event(
+ context.context().summary_writer_resource, tensor, name=name)
+
+
def eval_dir(model_dir, name=None):
"""Construct a logdir for an eval summary writer."""
return os.path.join(model_dir, "eval" if not name else "eval_" + name)
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py
index de7ae6ec27..6e1a746815 100644
--- a/tensorflow/contrib/summary/summary_ops_test.py
+++ b/tensorflow/contrib/summary/summary_ops_test.py
@@ -17,14 +17,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
+import os
import tempfile
+import six
+import sqlite3
+
from tensorflow.contrib.summary import summary_ops
from tensorflow.contrib.summary import summary_test_util
from tensorflow.python.eager import function
from tensorflow.python.eager import test
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.platform import gfile
from tensorflow.python.training import training_util
@@ -86,6 +94,120 @@ class TargetTest(test_util.TensorFlowTestCase):
self.assertEqual(len(events), 2)
self.assertEqual(events[1].summary.value[0].tag, 'scalar')
+ def testSummaryGlobalStep(self):
+ global_step = training_util.get_or_create_global_step()
+ logdir = tempfile.mkdtemp()
+ with summary_ops.create_summary_file_writer(
+ logdir, max_queue=0,
+ name='t2').as_default(), summary_ops.always_record_summaries():
+
+ summary_ops.scalar('scalar', 2.0, global_step=global_step)
+
+ events = summary_test_util.events_from_file(logdir)
+ self.assertEqual(len(events), 2)
+ self.assertEqual(events[1].summary.value[0].tag, 'scalar')
+
+
+class DbTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite')
+ if os.path.exists(self.db_path):
+ os.unlink(self.db_path)
+ self.db = sqlite3.connect(self.db_path)
+ self.create_summary_db_writer = functools.partial(
+ summary_ops.create_summary_db_writer,
+ db_uri=self.db_path,
+ experiment_name='experiment',
+ run_name='run',
+ user_name='user')
+
+ def tearDown(self):
+ self.db.close()
+
+ def testIntegerSummaries(self):
+ step = training_util.create_global_step()
+
+ def adder(x, y):
+ state_ops.assign_add(step, 1)
+ summary_ops.generic('x', x)
+ summary_ops.generic('y', y)
+ sum_ = x + y
+ summary_ops.generic('sum', sum_)
+ return sum_
+
+ with summary_ops.always_record_summaries():
+ with self.create_summary_db_writer().as_default():
+ self.assertEqual(5, adder(int64(2), int64(3)).numpy())
+
+ six.assertCountEqual(self, [1, 1, 1],
+ get_all(self.db, 'SELECT step FROM Tensors'))
+ six.assertCountEqual(self, ['x', 'y', 'sum'],
+ get_all(self.db, 'SELECT tag_name FROM Tags'))
+ x_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "x"')
+ y_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "y"')
+ sum_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "sum"')
+
+ with summary_ops.always_record_summaries():
+ with self.create_summary_db_writer().as_default():
+ self.assertEqual(9, adder(int64(4), int64(5)).numpy())
+
+ six.assertCountEqual(self, [1, 1, 1, 2, 2, 2],
+ get_all(self.db, 'SELECT step FROM Tensors'))
+ six.assertCountEqual(self, [x_id, y_id, sum_id],
+ get_all(self.db, 'SELECT tag_id FROM Tags'))
+ self.assertEqual(2, get_tensor(self.db, x_id, 1))
+ self.assertEqual(3, get_tensor(self.db, y_id, 1))
+ self.assertEqual(5, get_tensor(self.db, sum_id, 1))
+ self.assertEqual(4, get_tensor(self.db, x_id, 2))
+ self.assertEqual(5, get_tensor(self.db, y_id, 2))
+ self.assertEqual(9, get_tensor(self.db, sum_id, 2))
+ six.assertCountEqual(
+ self, ['experiment'],
+ get_all(self.db, 'SELECT experiment_name FROM Experiments'))
+ six.assertCountEqual(self, ['run'],
+ get_all(self.db, 'SELECT run_name FROM Runs'))
+ six.assertCountEqual(self, ['user'],
+ get_all(self.db, 'SELECT user_name FROM Users'))
+
+ def testBadExperimentName(self):
+ with self.assertRaises(ValueError):
+ self.create_summary_db_writer(experiment_name='\0')
+
+ def testBadRunName(self):
+ with self.assertRaises(ValueError):
+ self.create_summary_db_writer(run_name='\0')
+
+ def testBadUserName(self):
+ with self.assertRaises(ValueError):
+ self.create_summary_db_writer(user_name='-hi')
+ with self.assertRaises(ValueError):
+ self.create_summary_db_writer(user_name='hi-')
+ with self.assertRaises(ValueError):
+ self.create_summary_db_writer(user_name='@')
+
+
+def get_one(db, q, *p):
+ return db.execute(q, p).fetchone()[0]
+
+
+def get_all(db, q, *p):
+ return unroll(db.execute(q, p).fetchall())
+
+
+def get_tensor(db, tag_id, step):
+ return get_one(
+ db, 'SELECT tensor FROM Tensors WHERE tag_id = ? AND step = ?', tag_id,
+ step)
+
+
+def int64(x):
+ return array_ops.constant(x, dtypes.int64)
+
+
+def unroll(list_of_tuples):
+ return sum(list_of_tuples, ())
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/tensorboard/db/BUILD b/tensorflow/contrib/tensorboard/db/BUILD
index d8bbf87d2c..068e862650 100644
--- a/tensorflow/contrib/tensorboard/db/BUILD
+++ b/tensorflow/contrib/tensorboard/db/BUILD
@@ -45,10 +45,12 @@ cc_library(
tf_cc_test(
name = "summary_db_writer_test",
+ size = "small",
srcs = ["summary_db_writer_test.cc"],
deps = [
":summary_db_writer",
"//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/lib/db:sqlite",
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
index df64e36305..a26ad61660 100644
--- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
+++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
@@ -15,10 +15,12 @@ limitations under the License.
#include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
#include "tensorflow/contrib/tensorboard/db/schema.h"
+#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/lib/db/sqlite.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/snappy.h"
+#include "tensorflow/core/util/event.pb.h"
namespace tensorflow {
namespace {
@@ -86,13 +88,19 @@ class SummaryDbWriter : public SummaryWriterInterface {
TF_RETURN_IF_ERROR(BindTensor(t));
break;
}
- TF_RETURN_IF_ERROR(insert_tensor_.StepAndReset());
- return Status::OK();
+ return insert_tensor_.StepAndReset();
}
Status WriteEvent(std::unique_ptr<Event> e) override {
- // TODO(@jart): This will be used to load event logs.
- return errors::Unimplemented("WriteEvent");
+ mutex_lock ml(mu_);
+ TF_RETURN_IF_ERROR(InitializeParents());
+ if (e->what_case() == Event::WhatCase::kSummary) {
+ const Summary& summary = e->summary();
+ for (int i = 0; i < summary.value_size(); ++i) {
+ TF_RETURN_IF_ERROR(WriteSummary(e.get(), summary.value(i)));
+ }
+ }
+ return Status::OK();
}
Status WriteScalar(int64 global_step, Tensor t, const string& tag) override {
@@ -247,6 +255,24 @@ class SummaryDbWriter : public SummaryWriterInterface {
return Status::OK();
}
+ Status WriteSummary(const Event* e, const Summary::Value& summary)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int64 tag_id;
+ TF_RETURN_IF_ERROR(GetTagId(run_id_, summary.tag(), &tag_id));
+ insert_tensor_.BindInt(1, tag_id);
+ insert_tensor_.BindInt(2, e->step());
+ insert_tensor_.BindDouble(3, e->wall_time());
+ switch (summary.value_case()) {
+ case Summary::Value::ValueCase::kSimpleValue:
+ insert_tensor_.BindDouble(4, summary.simple_value());
+ break;
+ default:
+ // TODO(@jart): Handle the rest.
+ return Status::OK();
+ }
+ return insert_tensor_.StepAndReset();
+ }
+
mutex mu_;
Env* env_;
std::shared_ptr<Sqlite> db_ GUARDED_BY(mu_);
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
index d32904f97c..c1af51e7b7 100644
--- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
+++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
@@ -14,14 +14,19 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
+#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/db/sqlite.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/event.pb.h"
namespace tensorflow {
namespace {
+const float kTolerance = 1e-5;
+
Tensor MakeScalarInt64(int64 x) {
Tensor t(DT_INT64, TensorShape({}));
t.scalar<int64>()() = x;
@@ -41,7 +46,7 @@ class FakeClockEnv : public EnvWrapper {
class SummaryDbWriterTest : public ::testing::Test {
protected:
- void SetUp() override { db_ = Sqlite::Open("file::memory:").ValueOrDie(); }
+ void SetUp() override { db_ = Sqlite::Open(":memory:").ValueOrDie(); }
void TearDown() override {
if (writer_ != nullptr) {
@@ -158,5 +163,54 @@ TEST_F(SummaryDbWriterTest, TensorsWritten_RowsGetInitialized) {
QueryString("SELECT tensor FROM Tensors WHERE step = 2").empty());
}
+TEST_F(SummaryDbWriterTest, EmptyParentNames_NoParentsCreated) {
+ TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_));
+ TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy",
+ "this-is-metaaa"));
+ TF_ASSERT_OK(writer_->Flush());
+ ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Users"));
+ ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments"));
+ ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Runs"));
+ ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags"));
+ ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
+}
+
+TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) {
+ TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_));
+ std::unique_ptr<Event> e{new Event};
+ e->set_step(7);
+ e->set_wall_time(123.456);
+ Summary::Value* s = e->mutable_summary()->add_value();
+ s->set_tag("π");
+ s->set_simple_value(3.14f);
+ s = e->mutable_summary()->add_value();
+ s->set_tag("φ");
+ s->set_simple_value(1.61f);
+ TF_ASSERT_OK(writer_->WriteEvent(std::move(e)));
+ TF_ASSERT_OK(writer_->Flush());
+ ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tags"));
+ ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
+ int64 tag1_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'π'");
+ int64 tag2_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'φ'");
+ EXPECT_GT(tag1_id, 0LL);
+ EXPECT_GT(tag2_id, 0LL);
+ EXPECT_EQ(123.456, QueryDouble(strings::StrCat(
+ "SELECT computed_time FROM Tensors WHERE tag_id = ",
+ tag1_id, " AND step = 7")));
+ EXPECT_EQ(123.456, QueryDouble(strings::StrCat(
+ "SELECT computed_time FROM Tensors WHERE tag_id = ",
+ tag2_id, " AND step = 7")));
+ EXPECT_NEAR(3.14,
+ QueryDouble(strings::StrCat(
+ "SELECT tensor FROM Tensors WHERE tag_id = ", tag1_id,
+ " AND step = 7")),
+ kTolerance); // Summary::simple_value is float
+ EXPECT_NEAR(1.61,
+ QueryDouble(strings::StrCat(
+ "SELECT tensor FROM Tensors WHERE tag_id = ", tag2_id,
+ " AND step = 7")),
+ kTolerance);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
index 3965c087a1..916b9b3082 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
@@ -45,10 +45,7 @@ class TPUConfig(
is invoked once on each host. To be precise, with a global batch size
`train_batch_size` in `TPUEstimator` constructor, the batch size for each
shard is `train_batch_size` // #hosts. With Per-Core input pipeline
- deployment, the shard batch size is `train_batch_size` // #cores. Note
- that this only works for single-host TPU training now (tracked in
- b/67051042). For multi-host, please use Per-Core, i.e., `False` for
- `per_host_input_for_training`.
+ deployment, the shard batch size is `train_batch_size` // #cores.
tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred
within TPUEstimator, however when using ClusterSpec propagation in more
esoteric cluster configurations, you may need to specify the job name as a
@@ -109,3 +106,12 @@ class RunConfig(run_config_lib.RunConfig):
@property
def tpu_config(self):
return self._tpu_config
+
+ def replace(self, **kwargs):
+ if 'tpu_config' not in kwargs:
+ return super(RunConfig, self).replace(**kwargs)
+
+ tpu_config = kwargs.pop('tpu_config')
+ new_instance = super(RunConfig, self).replace(**kwargs)
+ new_instance._tpu_config = tpu_config # pylint: disable=protected-access
+ return new_instance
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 060b3f9129..07877fcc76 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -66,7 +66,7 @@ _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum'
_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY]
# TODO(b/65703635): Flip the value and remove all dead code.
-_WRAP_INPUT_FN_INTO_WHILE_LOOP = True
+_WRAP_INPUT_FN_INTO_WHILE_LOOP = False
def _create_global_step(graph):
@@ -232,8 +232,10 @@ class _TPUContext(object):
mode == model_fn_lib.ModeKeys.TRAIN
else self._eval_batch_size)
# On TPU
- return (global_batch_size // self.num_cores
- if self.is_input_sharded_per_core() else global_batch_size)
+ if self.is_input_sharded_per_core():
+ return global_batch_size // self.num_cores
+ else:
+ return global_batch_size // self.num_hosts
@property
def batch_size_for_model_fn(self):
@@ -535,13 +537,15 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
session, self._dequeue_ops)
def before_run(self, run_context):
- logging.info('Enqueue next batch of data to infeed.')
-
iterations = run_context.session.run(self._iterations_per_loop_var)
+
+ logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations)
+
self._infeed_thd_controller.send_next_batch_signal(iterations)
if self._dequeue_ops is not None:
# TODO(xiejw): Refactor the outfeed dequeue into tf.while_loop.
- logging.info('Dequeue next batch of data from outfeed.')
+ logging.info(
+ 'Dequeue next (%d) batch(es) of data from outfeed.', iterations)
self._outfeed_thd_controller.send_next_batch_signal(iterations)
def end(self, session):
@@ -680,6 +684,40 @@ def generate_per_core_enqueue_ops_fn_for_host(
return enqueue_ops_fn, (lambda: infeed_queue_holder['instance'])
+def generate_per_host_enqueue_ops_fn_for_host(
+ ctx, input_fn, inputs_structure_recorder, batch_axis, device):
+ """Generates infeed enqueue ops for per-host input_fn on a single host."""
+ infeed_queue_holder = {'instance': None}
+
+ def enqueue_ops_fn():
+ with ops.device(device):
+ num_cores_per_host = ctx.num_of_cores_per_host
+ inputs = input_fn()
+ if isinstance(inputs, tuple):
+ features, labels = inputs
+ else:
+ features, labels = inputs, None
+ inputs_structure_recorder.validate_and_record_structure(
+ features, labels)
+ unsharded_tensor_list = (
+ inputs_structure_recorder.flatten_features_and_labels(
+ features, labels))
+
+ infeed_queue = tpu_feed.InfeedQueue(
+ tuple_types=[t.dtype for t in unsharded_tensor_list],
+ tuple_shapes=[t.shape for t in unsharded_tensor_list],
+ shard_dimensions=batch_axis)
+ infeed_queue_holder['instance'] = infeed_queue
+ infeed_queue.set_number_of_shards(num_cores_per_host)
+
+ per_host_enqueue_ops = (
+ infeed_queue.split_inputs_and_generate_enqueue_ops(
+ unsharded_tensor_list,
+ placement_function=lambda x: device))
+ return per_host_enqueue_ops
+ return enqueue_ops_fn, (lambda: infeed_queue_holder['instance'])
+
+
class _InputPipeline(object):
"""`_InputPipeline` handles invoking `input_fn` and piping to infeed queue.
@@ -842,6 +880,8 @@ class _InputPipeline(object):
# structure is recorded.
enqueue_ops = self._invoke_input_fn_and_record_structure()
+ self._validate_input_pipeline()
+
def dequeue_fn():
"""dequeue_fn is used by TPU to retrieve the tensors."""
values = self._infeed_queue.generate_dequeue_op()
@@ -852,15 +892,15 @@ class _InputPipeline(object):
return (enqueue_ops, dequeue_fn)
def _invoke_input_fn_and_record_structure(self):
+ """Deploys the input pipeline and record input structure."""
+ enqueue_ops = []
+ infeed_queues = []
+ num_hosts = self._ctx.num_hosts
+ tpu_host_placement_fn = self._ctx.tpu_host_placement_function
if self._sharded_per_core:
# Per-Core input pipeline deployment.
- tpu_host_placement_fn = self._ctx.tpu_host_placement_function
- enqueue_ops = []
- infeed_queues = []
-
# Invoke input pipeline for each core and placed on the corresponding
# host.
- num_hosts = self._ctx.num_hosts
for host_id in range(num_hosts):
host_device = tpu_host_placement_fn(host_id=host_id)
with ops.device(host_device):
@@ -877,48 +917,52 @@ class _InputPipeline(object):
# Infeed_queue_getter must be called after enqueue_ops_fn is called.
infeed_queues.append(infeed_queue_getter())
- # infeed_queue is used to generate dequeue ops. The only thing it uses for
- # dequeue is dtypes and types. So, any one can be used. Here, grab the
- # first one.
- self._infeed_queue = infeed_queues[0]
- return enqueue_ops
-
else:
- # TODO(b/67051042): Extend this to multi-host support.
- host_id = 0
- host_device = self._ctx.tpu_host_placement_function(host_id=host_id)
- def enqueue_fn():
+ for host_id in range(num_hosts):
+ host_device = tpu_host_placement_fn(host_id=host_id)
with ops.device(host_device):
with ops.name_scope('input_pipeline_task%d' % (host_id)):
- inputs = self._input_fn()
- if isinstance(inputs, tuple):
- features, labels = inputs
- else:
- features, labels = inputs, None
- self._inputs_structure_recorder.validate_and_record_structure(
- features, labels)
- unsharded_tensor_list = (
- self._inputs_structure_recorder.flatten_features_and_labels(
- features, labels))
-
- self._infeed_queue = tpu_feed.InfeedQueue(
- tuple_types=[t.dtype for t in unsharded_tensor_list],
- tuple_shapes=[t.shape for t in unsharded_tensor_list],
- shard_dimensions=self._batch_axis)
- self._infeed_queue.set_number_of_shards(self._ctx.num_cores)
-
- def placement_fn(core_id):
- return self._ctx.tpu_host_placement_function(core_id=core_id)
- return (
- self._infeed_queue.split_inputs_and_generate_enqueue_ops(
- unsharded_tensor_list,
- placement_function=placement_fn))
+ enqueue_ops_fn, infeed_queue_getter = (
+ generate_per_host_enqueue_ops_fn_for_host(
+ self._ctx, self._input_fn, self._inputs_structure_recorder,
+ self._batch_axis, host_device))
+ if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
+ enqueue_ops.append(_wrap_computation_in_while_loop(
+ device=host_device, op_fn=enqueue_ops_fn))
+ else:
+ enqueue_ops.append(enqueue_ops_fn())
+ infeed_queues.append(infeed_queue_getter())
+ # infeed_queue is used to generate dequeue ops. The only thing it uses for
+ # dequeue is dtypes and types. So, any one can be used. Here, grab the
+ # first one.
+ self._infeed_queue = infeed_queues[0]
+ return enqueue_ops
+
+ def _validate_input_pipeline(self):
+ # Perform some sanity checks to log user friendly information. We should
+ # error out to give users better error message. But, if
+ # _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break
+ # user code, so, log a warning.
+ if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS):
+ err_msg = ('Input pipeline contains one or more QueueRunners. '
+ 'These are not supported via TPUEstimator. You must convert '
+ 'your input pipeline to use `tf.data` instead (see '
+ 'https://www.tensorflow.org/programmers_guide/datasets for '
+ 'instructions.')
if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
- return _wrap_computation_in_while_loop(device=host_device,
- op_fn=enqueue_fn)
+ raise RuntimeError(err_msg)
else:
- return enqueue_fn()
+ logging.warn(err_msg)
+ elif ops.get_default_graph().get_collection(ops.GraphKeys.SUMMARIES):
+ # Queue Runner has summary Ops by default. So here we use elif to do
+ # necessary checks for Dataset input pipeline only.
+ err_msg = ('Input pipeline contains `tf.summary` operations. '
+ 'These are not currently supported.')
+ if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
+ raise RuntimeError(err_msg)
+ else:
+ logging.warn(err_msg)
class _ModelFnWrapper(object):
@@ -1396,12 +1440,6 @@ class TPUEstimator(estimator_lib.Estimator):
'eval batch size {} must be divisible by number of shards {}'
.format(eval_batch_size, config.tpu_config.num_shards))
- if (config.tpu_config.num_shards > 8 and
- config.tpu_config.per_host_input_for_training):
- # TODO(b/67051042): Support per_host input pipelines when num_shards > 8
- raise NotImplementedError(
- 'Per-host input pipelines only available for num_shards <= 8')
-
# Verifies the model_fn signature according to Estimator framework.
estimator_lib._verify_model_fn_args(model_fn, params) # pylint: disable=protected-access
# We cannot store config and params in this constructor as parent
diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py
index 391899b34f..7db625cdd5 100644
--- a/tensorflow/contrib/training/python/training/hparam.py
+++ b/tensorflow/contrib/training/python/training/hparam.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import json
+import numbers
import re
import six
@@ -76,7 +77,7 @@ def _process_scalar_value(name, parse_fn, var_type, m_dict, values,
function.
Raises:
- ValueError: If the name has already been sued.
+ ValueError: If the name has already been used.
"""
try:
parsed_value = parse_fn(m_dict['val'])
@@ -138,6 +139,54 @@ def _process_list_value(name, parse_fn, var_type, m_dict, values,
_parse_fail(name, var_type, m_dict['vals'], values)
+def _cast_to_type_if_compatible(name, param_type, value):
+ """Cast hparam to the provided type, if compatible.
+
+ Args:
+ name: Name of the hparam to be cast.
+ param_type: The type of the hparam.
+ value: The value to be cast, if compatible.
+
+ Returns:
+ The result of casting `value` to `param_type`.
+
+ Raises:
+ ValueError: If the type of `value` is not compatible with param_type.
+ * If `param_type` is a string type, but `value` is not.
+ * If `param_type` is a boolean, but `value` is not, or vice versa.
+ * If `param_type` is an integer type, but `value` is not.
+ * If `param_type` is a float type, but `value` is not a numeric type.
+ """
+ fail_msg = (
+ "Could not cast hparam '%s' of type '%s' from value %r" %
+ (name, param_type, value))
+
+ # Some callers use None, for which we can't do any casting/checking. :(
+ if issubclass(param_type, type(None)):
+ return value
+
+ # Avoid converting a non-string type to a string.
+ if (issubclass(param_type, (six.string_types, six.binary_type)) and
+ not isinstance(value, (six.string_types, six.binary_type))):
+ raise ValueError(fail_msg)
+
+ # Avoid converting a number or string type to a boolean or vice versa.
+ if issubclass(param_type, bool) != isinstance(value, bool):
+ raise ValueError(fail_msg)
+
+ # Avoid converting float to an integer (the reverse is fine).
+ if (issubclass(param_type, numbers.Integral) and
+ not isinstance(value, numbers.Integral)):
+ raise ValueError(fail_msg)
+
+ # Avoid converting a non-numeric type to a numeric type.
+ if (issubclass(param_type, numbers.Number) and
+ not isinstance(value, numbers.Number)):
+ raise ValueError(fail_msg)
+
+ return param_type(value)
+
+
def parse_values(values, type_map):
"""Parses hyperparameter values from a string into a python map.
@@ -438,17 +487,18 @@ class HParams(object):
Raises:
ValueError: If there is a type mismatch.
"""
- _, is_list = self._hparam_types[name]
+ param_type, is_list = self._hparam_types[name]
if isinstance(value, list):
if not is_list:
raise ValueError(
'Must not pass a list for single-valued parameter: %s' % name)
- setattr(self, name, value)
+ setattr(self, name, [
+ _cast_to_type_if_compatible(name, param_type, v) for v in value])
else:
if is_list:
raise ValueError(
'Must pass a list for multi-valued parameter: %s.' % name)
- setattr(self, name, value)
+ setattr(self, name, _cast_to_type_if_compatible(name, param_type, value))
def parse(self, values):
"""Override hyperparameter values, parsing new values from a string.
diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py
index f54514cefd..949c262f5b 100644
--- a/tensorflow/contrib/training/python/training/hparam_test.py
+++ b/tensorflow/contrib/training/python/training/hparam_test.py
@@ -318,13 +318,42 @@ class HParamsTest(test.TestCase):
self.assertEqual(3.0, hparams.b)
self.assertEqual('relu4', hparams.c_c)
- def testSetHParamTypeMismatch(self):
+ def testSetHParamListNonListMismatch(self):
hparams = hparam.HParams(a=1, b=[2.0, 3.0])
with self.assertRaisesRegexp(ValueError, r'Must not pass a list'):
hparams.set_hparam('a', [1.0])
with self.assertRaisesRegexp(ValueError, r'Must pass a list'):
hparams.set_hparam('b', 1.0)
+ def testSetHParamTypeMismatch(self):
+ hparams = hparam.HParams(
+ int_=1, str_='str', bool_=True, float_=1.1, list_int=[1, 2], none=None)
+
+ with self.assertRaises(ValueError):
+ hparams.set_hparam('str_', 2.2)
+
+ with self.assertRaises(ValueError):
+ hparams.set_hparam('int_', False)
+
+ with self.assertRaises(ValueError):
+ hparams.set_hparam('bool_', 1)
+
+ with self.assertRaises(ValueError):
+ hparams.set_hparam('int_', 2.2)
+
+ with self.assertRaises(ValueError):
+ hparams.set_hparam('list_int', [2, 3.3])
+
+ with self.assertRaises(ValueError):
+ hparams.set_hparam('int_', '2')
+
+ # Casting int to float is OK
+ hparams.set_hparam('float_', 1)
+
+ # Getting stuck with NoneType :(
+ hparams.set_hparam('none', '1')
+ self.assertEqual('1', hparams.none)
+
def testNonProtoFails(self):
with self.assertRaisesRegexp(AssertionError, ''):
hparam.HParams(hparam_def=1)
diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc
index 38fe247521..6399b8cf55 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/bfc_allocator.cc
@@ -296,12 +296,13 @@ void* BFCAllocator::FindChunkPtr(BinNum bin_num, size_t rounded_bytes,
// it from the free bin structure prior to using.
RemoveFreeChunkIterFromBin(&b->free_chunks, citer);
- // If we can break the size of the chunk into two reasonably
- // large pieces, do so.
- //
- // TODO(vrv): What should be the criteria when deciding when
- // to split?
- if (chunk->size >= rounded_bytes * 2) {
+ // If we can break the size of the chunk into two reasonably large
+ // pieces, do so. In any case don't waste more than
+ // kMaxInternalFragmentation bytes on padding this alloc.
+ const int64 kMaxInternalFragmentation = 128 << 20; // 128mb
+ if (chunk->size >= rounded_bytes * 2 ||
+ static_cast<int64>(chunk->size) - rounded_bytes >=
+ kMaxInternalFragmentation) {
SplitChunk(h, rounded_bytes);
chunk = ChunkFromHandle(h); // Update chunk pointer in case it moved
}
diff --git a/tensorflow/core/framework/bfloat16.cc b/tensorflow/core/framework/bfloat16.cc
index a5ac0e1a8d..1a6f355c77 100644
--- a/tensorflow/core/framework/bfloat16.cc
+++ b/tensorflow/core/framework/bfloat16.cc
@@ -18,32 +18,24 @@ limitations under the License.
namespace tensorflow {
void FloatToBFloat16(const float* src, bfloat16* dst, int64 size) {
- const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
- uint16_t* q = reinterpret_cast<uint16_t*>(dst);
-#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
- for (; size != 0; p += 2, q++, size--) {
- *q = p[0];
- }
-#else
- for (; size != 0; p += 2, q++, size--) {
- *q = p[1];
- }
-#endif
+ for (int64 i = 0; i < size; ++i) {
+ dst[i] = bfloat16(src[i]);
+ }
}
void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size) {
const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
uint16_t* q = reinterpret_cast<uint16_t*>(dst);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
- for (; size != 0; p++, q += 2, size--) {
- q[0] = *p;
- q[1] = 0;
+ for (; size != 0; p++, q += 2, size--) {
+ q[0] = *p;
+ q[1] = 0;
}
-#else
- for (; size != 0; p++, q += 2, size--) {
- q[0] = 0;
- q[1] = *p;
- }
+#else
+ for (; size != 0; p++, q += 2, size--) {
+ q[0] = 0;
+ q[1] = *p;
+ }
#endif
}
diff --git a/tensorflow/core/framework/bfloat16_test.cc b/tensorflow/core/framework/bfloat16_test.cc
index af4e6a4411..a25b764ea2 100644
--- a/tensorflow/core/framework/bfloat16_test.cc
+++ b/tensorflow/core/framework/bfloat16_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/bfloat16.h"
+#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -27,6 +28,97 @@ TEST(Bfloat16Test, Simple) {
EXPECT_EQ(0x4140, a.value);
}
+float BinaryToFloat(uint32_t sign, uint32_t exponent, uint32_t high_mantissa,
+ uint32_t low_mantissa) {
+ return bit_cast<float>((sign << 31) + (exponent << 23) +
+ (high_mantissa << 16) + low_mantissa);
+}
+
+struct Bfloat16TestParam {
+ float input;
+ float expected;
+};
+
+class Bfloat16Test : public ::testing::Test,
+ public ::testing::WithParamInterface<Bfloat16TestParam> {};
+
+TEST_P(Bfloat16Test, RoundOrTruncate) {
+ bfloat16 a(GetParam().input);
+ if (std::isnan(GetParam().input)) {
+ EXPECT_TRUE(std::isnan(float(a)));
+ return;
+ }
+ EXPECT_EQ(GetParam().expected, float(a));
+}
+
+INSTANTIATE_TEST_CASE_P(
+ Bfloat16Test_Instantiation, Bfloat16Test,
+ ::testing::Values(
+ // More than half.
+ Bfloat16TestParam{
+ BinaryToFloat(0, 0b10000000, 0b1001000, 0b1111010111000011),
+ BinaryToFloat(0, 0b10000000, 0b1001001, 0b0000000000000000)},
+
+ Bfloat16TestParam{
+ BinaryToFloat(1, 0b10000000, 0b1001000, 0b1111010111000011),
+ BinaryToFloat(1, 0b10000000, 0b1001001, 0b0000000000000000)},
+
+ // Exact half.
+ Bfloat16TestParam{
+ BinaryToFloat(0, 0b10000000, 0b1001000, 0b1000000000000000),
+ BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
+
+ // NaN stays at NaN.
+ Bfloat16TestParam{
+ BinaryToFloat(0, 0b11111111, 0b0000000, 0b0000000000000001),
+ BinaryToFloat(0, 0b11111111, 0b1000000, 0b0000000000000000)},
+
+ // NaN stays at NaN -- no exponents overflow.
+ Bfloat16TestParam{
+ BinaryToFloat(0, 0b11111111, 0b1111111, 0b1111111111111111),
+ BinaryToFloat(0, 0b11111111, 0b1000000, 0b0000000000000000)},
+
+ // More than half, round to an odd number.
+ Bfloat16TestParam{
+ BinaryToFloat(1, 0b10000000, 0b1001000, 0b1100000000000000),
+ BinaryToFloat(1, 0b10000000, 0b1001001, 0b0000000000000000)},
+
+ // Less than half, truncate.
+ Bfloat16TestParam{
+ BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000),
+ BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
+
+ // Less than half, truncate.
+ Bfloat16TestParam{
+ BinaryToFloat(0, 0b10000000, 0b1001000, 0b0100000000000000),
+ BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
+
+ // Exact at half, but result is already even.
+ Bfloat16TestParam{
+ BinaryToFloat(0, 0b10000000, 0b1001000, 0b1000000000000000),
+ BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
+
+ // Denormal values.
+ Bfloat16TestParam{
+ BinaryToFloat(0, 0b00000000, 0b1001000, 0b1000000000000000),
+ BinaryToFloat(0, 0b00000000, 0b1001000, 0b0000000000000000)},
+ Bfloat16TestParam{
+ BinaryToFloat(0, 0b00000000, 0b1111111, 0b1100000000000000),
+ BinaryToFloat(0, 0b00000001, 0b0000000, 0b0000000000000000)}));
+TEST(Bfloat16Test, RoundWithFractionOverflow) {
+ // Still works with fraction overflow -- round to 4./
+ //
+ // Input 3.9960938:
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
+ // 0 1 0 0 0 0 0 0 1 1 1 1 1 1 1 1100000000000000
+ //
+ // Should round to 4.0:
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
+ // 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0
+ bfloat16 a(3.9960938f);
+ EXPECT_EQ(4.0, float(a));
+}
+
TEST(Bfloat16Test, Conversion) {
float a[100];
for (int i = 0; i < 100; ++i) {
diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h
index a630bee38d..d005de2af1 100644
--- a/tensorflow/core/framework/numeric_types.h
+++ b/tensorflow/core/framework/numeric_types.h
@@ -44,29 +44,262 @@ typedef Eigen::QUInt16 quint16;
// see framework/bfloat16.h for description.
struct bfloat16 {
EIGEN_DEVICE_FUNC bfloat16() {}
- EIGEN_DEVICE_FUNC explicit bfloat16(const float v) {
- const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
+
+ explicit EIGEN_DEVICE_FUNC bfloat16(float v) {
+ uint32_t input;
+ memcpy(&input, &v, sizeof(uint32_t));
+
+ if ((~input & 0x7f800000) == 0 && (input & 0x007fffff) != 0) {
+ // If the value is a NaN, squash it to a qNaN with msb of fraction set,
+ // this makes sure after truncation we don't end up with an inf.
+ //
+ // qNaN magic: All exponent bits set + most significant bit of fraction
+ // set.
+ value = 0x7fc0;
+ } else {
+ // Fast rounding algorithm that rounds a half value to nearest even. This
+ // reduces expected error when we convert a large number of floats. Here
+ // is how it works:
+ //
+ // Definitions:
+ // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
+ // with the following tags:
+ //
+ // Sign | Exp (8 bits) | Frac (23 bits)
+ // S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT
+ //
+ // S: Sign bit.
+ // E: Exponent bits.
+ // F: First 6 bits of fraction.
+ // L: Least significant bit of resulting bfloat16 if we truncate away the
+ // rest of the float32. This is also the 7th bit of fraction
+ // R: Rounding bit, 8th bit of fraction.
+ // T: Sticky bits, rest of fraction, 15 bits.
+ //
+ // To round half to nearest even, there are 3 cases where we want to round
+ // down (simply truncate the result of the bits away, which consists of
+ // rounding bit and sticky bits) and two cases where we want to round up
+ // (truncate then add one to the result).
+ //
+ // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
+ // 1s) as the rounding bias, adds the rounding bias to the input, then
+ // truncates the last 16 bits away.
+ //
+ // To understand how it works, we can analyze this algorithm case by case:
+ //
+ // 1. L = 0, R = 0:
+ // Expect: round down, this is less than half value.
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 0 = 0x7fff
+ // - Adding rounding bias to input may create any carry, depending on
+ // whether there is any value set to 1 in T bits.
+ // - R may be set to 1 if there is a carry.
+ // - L remains 0.
+ // - Note that this case also handles Inf and -Inf, where all fraction
+ // bits, including L, R and Ts are all 0. The output remains Inf after
+ // this algorithm.
+ //
+ // 2. L = 1, R = 0:
+ // Expect: round down, this is less than half value.
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 1 = 0x8000
+ // - Adding rounding bias to input doesn't change sticky bits but
+ // adds 1 to rounding bit.
+ // - L remains 1.
+ //
+ // 3. L = 0, R = 1, all of T are 0:
+ // Expect: round down, this is exactly at half, the result is already
+ // even (L=0).
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 0 = 0x7fff
+ // - Adding rounding bias to input sets all sticky bits to 1, but
+ // doesn't create a carry.
+ // - R remains 1.
+ // - L remains 0.
+ //
+ // 4. L = 1, R = 1:
+ // Expect: round up, this is exactly at half, the result needs to be
+ // round to the next even number.
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 1 = 0x8000
+ // - Adding rounding bias to input doesn't change sticky bits, but
+ // creates a carry from rounding bit.
+ // - The carry sets L to 0, creates another carry bit and propagate
+ // forward to F bits.
+ // - If all the F bits are 1, a carry then propagates to the exponent
+ // bits, which then creates the minimum value with the next exponent
+ // value. Note that we won't have the case where exponents are all 1,
+ // since that's either a NaN (handled in the other if condition) or inf
+ // (handled in case 1).
+ //
+ // 5. L = 0, R = 1, any of T is 1:
+ // Expect: round up, this is greater than half.
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 0 = 0x7fff
+ // - Adding rounding bias to input creates a carry from sticky bits,
+ // sets rounding bit to 0, then create another carry.
+ // - The second carry sets L to 1.
+ //
+ // Examples:
+ //
+ // Exact half value that is already even:
+ // Input:
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
+ // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
+ // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000
+ //
+ // This falls into case 3. We truncate the rest of 16 bits and no
+ // carry is created into F and L:
+ //
+ // Output:
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
+ // S E E E E E E E E F F F F F F L
+ // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
+ //
+ // Exact half value, round to next even number:
+ // Input:
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
+ // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
+ // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000
+ //
+ // This falls into case 4. We create a carry from R and T,
+ // which then propagates into L and F:
+ //
+ // Output:
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
+ // S E E E E E E E E F F F F F F L
+ // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
+ //
+ //
+ // Max denormal value round to min normal value:
+ // Input:
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
+ // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
+ // 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111
+ //
+ // This falls into case 4. We create a carry from R and T,
+ // propagate into L and F, which then propagates into exponent
+ // bits:
+ //
+ // Output:
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
+ // S E E E E E E E E F F F F F F L
+ // 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
+ //
+ // Max normal value round to Inf:
+ // Input:
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
+ // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
+ // 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111
+ //
+ // This falls into case 4. We create a carry from R and T,
+ // propagate into L and F, which then propagates into exponent
+ // bits:
+ //
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
+ // S E E E E E E E E F F F F F F L
+ // 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
+ //
+ //
+ // Least significant bit of resulting bfloat.
+ uint32_t lsb = (input >> 16) & 1;
+ uint32_t rounding_bias = 0x7fff + lsb;
+ input += rounding_bias;
+ value = static_cast<uint16_t>(input >> 16);
+ }
+ }
+
+ template <class T>
+ explicit EIGEN_DEVICE_FUNC bfloat16(const T& val)
+ : bfloat16(static_cast<float>(val)) {}
+
+ EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const {
+ float result;
+
+ uint16_t* q = reinterpret_cast<uint16_t*>(&result);
+
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
- value = p[0];
+ q[0] = value;
+ q[1] = 0;
#else
- value = p[1];
+ q[0] = 0;
+ q[1] = value;
#endif
+ return result;
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator bool() const {
+ return static_cast<bool>(float(*this));
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator Eigen::half() const {
+ return static_cast<Eigen::half>(float(*this));
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator short() const {
+ return static_cast<short>(float(*this));
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator int() const {
+ return static_cast<int>(float(*this));
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator char() const {
+ return static_cast<char>(float(*this));
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator signed char() const {
+ return static_cast<signed char>(float(*this));
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator unsigned char() const {
+ return static_cast<unsigned char>(float(*this));
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator unsigned int() const {
+ return static_cast<unsigned int>(float(*this));
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator unsigned long() const {
+ return static_cast<unsigned long>(float(*this));
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator unsigned long long() const {
+ return static_cast<unsigned long long>(float(*this));
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator long long() const {
+ return static_cast<long long>(float(*this));
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator double() const {
+ return static_cast<double>(float(*this));
}
uint16_t value;
};
+inline bool operator==(const bfloat16 a, const bfloat16 b) {
+ return a.value == b.value;
+}
+
+inline bool operator!=(const bfloat16 a, const bfloat16 b) {
+ return a.value != b.value;
+}
+
} // end namespace tensorflow
namespace Eigen {
template <>
struct NumTraits<tensorflow::bfloat16> : GenericNumTraits<uint16_t> {};
-EIGEN_STRONG_INLINE bool operator==(const tensorflow::bfloat16 a,
- const tensorflow::bfloat16 b) {
- return a.value == b.value;
-}
-
+using ::tensorflow::operator==;
+using ::tensorflow::operator!=;
} // namespace Eigen
#ifdef COMPILER_MSVC
diff --git a/tensorflow/core/framework/op_gen_lib.cc b/tensorflow/core/framework/op_gen_lib.cc
index 1e93e9be09..d84d5431e9 100644
--- a/tensorflow/core/framework/op_gen_lib.cc
+++ b/tensorflow/core/framework/op_gen_lib.cc
@@ -84,7 +84,7 @@ static bool SplitAt(char split_ch, StringPiece* orig,
auto pos = orig->find(split_ch);
if (pos == StringPiece::npos) {
*before_split = *orig;
- orig->clear();
+ *orig = StringPiece();
return false;
} else {
*before_split = orig->substr(0, pos);
@@ -236,7 +236,7 @@ string PBTxtFromMultiline(StringPiece multiline_pbtxt) {
unescaped.push_back('\n');
}
strings::StrAppend(&unescaped, line);
- line.clear();
+ line = StringPiece();
}
// Escape what we extracted and then output it in quotes.
diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h
index c31ab18cc1..4bb37e4f6e 100644
--- a/tensorflow/core/framework/register_types.h
+++ b/tensorflow/core/framework/register_types.h
@@ -87,7 +87,8 @@ limitations under the License.
#elif defined(__ANDROID_TYPES_FULL__)
-// Only half, float, int32, int64, bool, and quantized types are supported.
+// Only string, half, float, int32, int64, bool, and quantized types
+// supported.
#define TF_CALL_float(m) m(float)
#define TF_CALL_double(m)
#define TF_CALL_int32(m) m(::tensorflow::int32)
@@ -96,7 +97,7 @@ limitations under the License.
#define TF_CALL_int16(m)
#define TF_CALL_int8(m)
-#define TF_CALL_string(m)
+#define TF_CALL_string(m) m(string)
#define TF_CALL_resource(m)
#define TF_CALL_variant(m)
#define TF_CALL_complex64(m)
diff --git a/tensorflow/core/framework/rendezvous.cc b/tensorflow/core/framework/rendezvous.cc
index a9e4c1cfb1..90756a4f2f 100644
--- a/tensorflow/core/framework/rendezvous.cc
+++ b/tensorflow/core/framework/rendezvous.cc
@@ -36,15 +36,15 @@ namespace tensorflow {
Rendezvous::ParsedKey& Rendezvous::ParsedKey::operator=(const ParsedKey& b) {
const char* b_base = b.buf_.data();
buf_ = b.buf_;
- src_device.set(buf_.data() + (b.src_device.data() - b_base),
- b.src_device.size());
+ src_device = StringPiece(buf_.data() + (b.src_device.data() - b_base),
+ b.src_device.size());
src = b.src;
src_incarnation = b.src_incarnation;
- dst_device.set(buf_.data() + (b.dst_device.data() - b_base),
- b.dst_device.size());
+ dst_device = StringPiece(buf_.data() + (b.dst_device.data() - b_base),
+ b.dst_device.size());
dst = b.dst;
- edge_name.set(buf_.data() + (b.edge_name.data() - b_base),
- b.edge_name.size());
+ edge_name = StringPiece(buf_.data() + (b.edge_name.data() - b_base),
+ b.edge_name.size());
return *this;
}
@@ -104,9 +104,9 @@ Status Rendezvous::ParseKey(StringPiece key, ParsedKey* out) {
strings::HexStringToUint64(parts[1], &out->src_incarnation) &&
DeviceNameUtils::ParseFullName(parts[2], &out->dst) &&
!parts[3].empty()) {
- out->src_device.set(parts[0].data(), parts[0].size());
- out->dst_device.set(parts[2].data(), parts[2].size());
- out->edge_name.set(parts[3].data(), parts[3].size());
+ out->src_device = StringPiece(parts[0].data(), parts[0].size());
+ out->dst_device = StringPiece(parts[2].data(), parts[2].size());
+ out->edge_name = StringPiece(parts[3].data(), parts[3].size());
return Status::OK();
}
return errors::InvalidArgument("Invalid rendezvous key: ", key);
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 753cb260e5..2ee409768b 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -68,7 +68,8 @@ class GraphConstructor {
Options(const GraphConstructorOptions& in) // NOLINT(runtime/explicit)
: allow_internal_ops(in.allow_internal_ops),
expect_device_spec(in.expect_device_spec),
- importing(false) {}
+ importing(false),
+ validate_colocation_constraints(false) {}
Options(const ImportGraphDefOptions& in) // NOLINT(runtime/explicit)
: allow_internal_ops(false),
expect_device_spec(false),
@@ -81,7 +82,8 @@ class GraphConstructor {
control_dependencies(in.control_dependencies),
return_tensors(in.return_tensors),
return_nodes(in.return_nodes),
- importing(true) {}
+ importing(true),
+ validate_colocation_constraints(in.validate_colocation_constraints) {}
bool allow_internal_ops;
bool expect_device_spec;
@@ -103,6 +105,7 @@ class GraphConstructor {
// applicable to ConvertGraphDefToGraph as well, so make an attempt to
// remove this.
bool importing;
+ bool validate_colocation_constraints;
};
typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice;
@@ -492,7 +495,8 @@ Status GraphConstructor::InitFromEdges() {
Status GraphConstructor::ValidateColocationConstraints(
const NodeDef& node_def) {
- if (!opts_.importing) return Status::OK();
+ if (!opts_.validate_colocation_constraints || !opts_.importing)
+ return Status::OK();
const auto iter = node_def.attr().find(kColocationAttrName);
if (iter == node_def.attr().end()) return Status::OK();
for (const string& c : iter->second.list().s()) {
diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h
index 416c0ee9ae..4b418b8622 100644
--- a/tensorflow/core/graph/graph_constructor.h
+++ b/tensorflow/core/graph/graph_constructor.h
@@ -119,6 +119,9 @@ struct ImportGraphDefOptions {
// TODO(skyewm): make this work with `skip_mapped_nodes` if there's a need.
std::vector<string> return_nodes;
+ // If true, checks that all colocation constraints are nodes in the GraphDef.
+ bool validate_colocation_constraints = true;
+
// TODO(ashankar): Enable handling of GraphDefs produced by newer binaries
// with ops that are not defined in the binary calling ImportGraphDef.
// Similar to the producer_op_list argument to import_graph_def in the
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index cd541c7d86..893826da3e 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -2978,5 +2978,20 @@ versions {
EXPECT_EQ(17, refiner.graph_def_version());
}
+TEST_F(GraphConstructorTest, ImportGraphDef_ValidateColationConstraints) {
+ GraphDef def;
+ ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
+ "node { name: 'A' op: 'TestInput' attr { key: '_class' value { list { "
+ "s:'loc:@missing' } } } }",
+ &def));
+ ImportGraphDefOptions options;
+ // TODO(yaozhang): Extend ExpectError to check error type and use ExpectError
+ // and ExpectOK to replace the code below.
+ Status s = ImportGraphDef(options, def, &graph_, nullptr);
+ EXPECT_TRUE(errors::IsInvalidArgument(s)) << s;
+ options.validate_colocation_constraints = false;
+ TF_EXPECT_OK(ImportGraphDef(options, def, &graph_, nullptr));
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 44322a2d8c..11654a6a28 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -50,13 +50,9 @@ template <typename Handle>
struct HandleToObject {};
template <>
struct HandleToObject<ShapeHandle> {
- typedef TensorShapeProto Object;
+ typedef ShapeHandle Object;
- static TensorShapeProto Unknown() {
- TensorShapeProto result;
- result.set_unknown_rank(true);
- return result;
- }
+ static ShapeHandle Unknown() { return ShapeHandle(); }
};
template <>
@@ -67,13 +63,24 @@ struct HandleToObject<DimensionHandle> {
};
template <typename Handle>
-struct Processor {
+struct Processor {};
+
+template <>
+struct Processor<ShapeHandle> {
// Extract the shape or dim denoted by the handle.
- void ExtractValue(Handle /*t1*/,
- typename HandleToObject<Handle>::Object* result) {}
+ void ExtractValue(ShapeHandle h, ShapeHandle* result) { *result = h; }
// Merge the shapes or dims.
- Status Merge(Handle /*t1*/, Handle /*t2*/,
- typename HandleToObject<Handle>::Object* result) {
+ Status Merge(ShapeHandle h1, ShapeHandle h2, ShapeHandle* result) {
+ if (InferenceContext::RankKnown(*result)) {
+ // The result was initialized in a previous merge to a shape of known
+ // rank, make sure we preserve that information.
+ return Status::OK();
+ }
+ if (InferenceContext::RankKnown(h1)) {
+ *result = h1;
+ } else {
+ *result = h2;
+ }
return Status::OK();
}
};
@@ -101,24 +108,37 @@ struct Processor<DimensionHandle> {
if (dim1 >= 0 && dim2 >= 0) {
CHECK_EQ(dim1, dim2);
- *result = dim1;
+ return RefineDim(dim1, result);
} else if (dim1 >= 0 && dim2 < 0) {
- *result = dim1;
+ return RefineDim(dim1, result);
} else if (dim1 < 0 && dim2 >= 0) {
- *result = dim2;
+ return RefineDim(dim2, result);
} else if (dim1 < -1) {
- *result = dim1;
+ return RefineDim(dim1, result);
} else if (dim2 < -1) {
- *result = dim2;
+ return RefineDim(dim2, result);
} else {
CHECK_EQ(dim1, dim2);
CHECK_EQ(-1, dim1);
- *result = -1;
+ return RefineDim(-1, result);
}
return Status::OK();
}
private:
+ Status RefineDim(int64 dim, int64* result) {
+ if (*result >= 0) {
+ if (!(*result == dim || dim < 0)) {
+ return errors::InvalidArgument("Inconsistent dimensions detected");
+ }
+ } else if (dim >= 0) {
+ *result = dim;
+ } else if (dim < *result) {
+ *result = dim;
+ }
+ return Status::OK();
+ }
+
int64 counter = 2;
};
@@ -354,18 +374,17 @@ class SymbolicShapeManager {
return dims_.Merge(d1, d2);
}
- int64 Value(DimensionHandle d) { return dims_.GetMergedValue(d); }
-
void AsTensorProperties(const ShapeHandle& shape, const DataType& type,
- InferenceContext* ctx,
OpInfo::TensorProperties* properties) {
properties->set_dtype(type);
- if (!ctx->RankKnown(shape)) {
+ ShapeHandle actual_shape = shapes_.GetMergedValue(shape);
+ if (!InferenceContext::RankKnown(actual_shape)) {
properties->mutable_shape()->set_unknown_rank(true);
} else {
- for (int j = 0; j < ctx->Rank(shape); ++j) {
- shape_inference::DimensionHandle dim = ctx->Dim(shape, j);
- int64 d = Value(dim);
+ for (int j = 0; j < InferenceContext::Rank(actual_shape); ++j) {
+ shape_inference::DimensionHandle dim =
+ InferenceContext::DimKnownRank(actual_shape, j);
+ int64 d = dims_.GetMergedValue(dim);
properties->mutable_shape()->add_dim()->set_size(d);
}
}
@@ -447,6 +466,11 @@ Status GraphProperties::InferStatically() {
shape_refiner.set_disable_constant_propagation(true);
shape_refiner.set_function_library_for_shape_inference(&function_library);
ImportGraphDefOptions options;
+ // Graph optimization happens at the late stage of graph execution,
+ // when colocation constraints are already validated previously and
+ // the device placement of nodes has also completed, so there
+ // is no need to validate colocation constraints again.
+ options.validate_colocation_constraints = false;
Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner);
TF_RETURN_IF_ERROR(s);
@@ -472,41 +496,6 @@ Status GraphProperties::InferStatically() {
}
}
}
-
- // Infer output shape for Restore op.
- if (node->op_def().name() == "Restore" ||
- node->op_def().name() == "RestoreV2" ||
- node->op_def().name() == "RestoreSlice") {
- auto ctx = shape_refiner.GetContext(node);
- for (const Edge* out_edge : node->out_edges()) {
- const Node* output = out_edge->dst();
- int output_idx = out_edge->src_output();
- if (output_idx < 0) {
- continue;
- }
- if (!ctx->FullyDefined(ctx->output(output_idx)) &&
- output->op_def().name() == "Assign") {
- if (!output->attrs().Find("validate_shape") ||
- !output->attrs().Find("validate_shape")->b()) {
- continue;
- }
- auto output_ctx = shape_refiner.GetContext(output);
- if (output_ctx->FullyDefined(output_ctx->output(0))) {
- ctx->set_output(output_idx, output_ctx->output(0));
- output_ctx->MergeInput(1, output_ctx->output(0));
- } else {
- const Node* var;
- TF_CHECK_OK(node->input_node(0, &var));
- if (node->IsVariable()) {
- auto var_ctx = shape_refiner.GetContext(var);
- CHECK(var_ctx->FullyDefined(var_ctx->output(0)));
- ctx->set_output(output_idx, var_ctx->output(0));
- output_ctx->MergeInput(1, var_ctx->output(0));
- }
- }
- }
- }
- }
}
// Propagate the initial shapes of Enter nodes manually (the Enter shape
@@ -688,7 +677,7 @@ Status GraphProperties::InferStatically() {
input_properties.resize(ctx->num_inputs());
for (int i = 0; i < ctx->num_inputs(); ++i) {
shape_manager.AsTensorProperties(ctx->input(i), node->input_type(i),
- ctx, &input_properties[i]);
+ &input_properties[i]);
}
for (const auto& edge : node->in_edges()) {
if (!edge->src()->IsConstant()) {
@@ -715,7 +704,7 @@ Status GraphProperties::InferStatically() {
output_properties.resize(ctx->num_outputs());
for (int i = 0; i < ctx->num_outputs(); ++i) {
shape_manager.AsTensorProperties(ctx->output(i), node->output_type(i),
- ctx, &output_properties[i]);
+ &output_properties[i]);
}
}
}
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index a33cdacc09..f785f627e1 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/grappler/inputs/utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -295,10 +296,9 @@ TEST_F(GraphPropertiesTest, Queues) {
ASSERT_EQ(1, props2.size());
EXPECT_EQ("float: [3,7]", PropToString(props2[0]));
- // The dequeue3 op shape is unknown.
const auto props3 = properties.GetOutputProperties("Dequeue3");
ASSERT_EQ(1, props3.size());
- EXPECT_EQ("float: ?", PropToString(props3[0]));
+ EXPECT_EQ("float: [3,7]", PropToString(props3[0]));
// The dequeue3 op shape is unknown. The square2 op shape is known. Verify
// that we merge the 2 properly to determine the shape of the data coming out
@@ -677,8 +677,8 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape) {
TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output var =
- ops::Variable(s.WithOpName("var"), TensorShape(), DataType::DT_FLOAT);
+ Output var = ops::Variable(s.WithOpName("var"), PartialTensorShape(),
+ DataType::DT_FLOAT);
Output var2 = ops::Variable(s.WithOpName("var2"), TensorShape({128, 256}),
DataType::DT_FLOAT);
Output filename =
@@ -784,6 +784,30 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) {
EXPECT_EQ(shape_f.dim(1).size(), shape_a.dim(1).size());
}
+TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 1.0f, {1});
+ Output b = ops::Const(s.WithOpName("b"), 2.0f, {1});
+ Output c = ops::Const(s.WithOpName("c").ColocateWith(a), 3.0f, {1});
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ // Create a graph with node a removed (say by some graph optimization
+ // pass), noting that node c is colocated with a. This is fine as it
+ // is in the late stage of graph execution, the colocation constraints have
+ // been validated previously and the device placement of nodes has completed.
+ GraphDef optimized_graph;
+ for (const auto& node : item.graph.node()) {
+ if (node.name() != "a") {
+ *optimized_graph.add_node() = node;
+ }
+ }
+ item.graph.Swap(&optimized_graph);
+ GraphProperties properties(item);
+ // This function should return OK, since it doesn't validate the colocation
+ // constraints internally.
+ TF_EXPECT_OK(properties.InferStatically());
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index d5625ae58f..2ab3a9144c 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -154,6 +154,16 @@ Status VirtualScheduler::Init() {
name_to_node[node->name()] = node;
}
+ // TODO(dyoon): Instead of identifying _Send node here manually, add _Send
+ // to _Recv as control dependency when creating GrapplerItem.
+ std::unordered_map<string, const NodeDef*> name_to_send;
+ for (const auto& node : graph.node()) {
+ if (node.op() == "_Send") {
+ const auto& attr = node.attr();
+ name_to_send[attr.at("tensor_name").s()] = &node;
+ }
+ }
+
// To reuse _Recv ops.
std::unordered_map<RecvNodeDescriptor, const NodeDef*, RecvNodeDescritorHash,
RecvNodeDescriptorEqual>
@@ -164,7 +174,17 @@ Status VirtualScheduler::Init() {
for (const auto* curr_node : nodes) {
auto& curr_node_state = GetNodeStateOrCreateIt(curr_node);
const string curr_node_device = DeviceName(curr_node);
- for (const string& input_node_name : curr_node->input()) {
+ std::vector<string> inputs;
+ if (IsRecv(*curr_node)) {
+ const auto& attr = curr_node->attr();
+ const NodeDef* send = name_to_send[attr.at("tensor_name").s()];
+ inputs = {send->name()};
+ } else {
+ for (const string& input : curr_node->input()) {
+ inputs.push_back(input);
+ }
+ }
+ for (const string& input_node_name : inputs) {
// Note that input_node_name may be in <prefix><node_name>:<port_num>
// format, where <prefix> (e.g., "^" for control dependency) and
// ":<port_num>" may be omitted. NodeName() extracts only the node_name.
@@ -219,7 +239,7 @@ Status VirtualScheduler::Init() {
// Default case: node without inputs are ready at time 0.
const bool has_no_inputs = curr_node->input().empty();
- if (given_as_feed || has_no_inputs) {
+ if (!IsRecv(*curr_node) && (given_as_feed || has_no_inputs)) {
curr_node_state.time_ready = Costs::Duration();
ready_nodes_->AddNode(curr_node);
VLOG(3) << "Added ready node: " << curr_node->name();
@@ -254,7 +274,10 @@ void VirtualScheduler::MaybeUpdateInputOutput(const NodeDef* node) {
// This method is called when NodeState is created and adds input and output
// properties for a few exceptional cases that GraphProperties cannot provide
// input/output properties.
- if (IsSend(*node) || IsRecv(*node)) {
+ if ((IsSend(*node) || IsRecv(*node)) && node->attr().count(kAttrInputSrc)) {
+ // _Send and _Recv ops created from VirtualScheduler have kAttrInputSrc
+ // attr; normal _Send and _Recv ops (from the input graph) do not have that
+ // attr.
auto& node_state = node_map_[node];
auto& inputs = node_state.input_properties;
auto& outputs = node_state.output_properties;
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
index d291a04308..40548b5a07 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
@@ -265,6 +265,127 @@ class VirtualSchedulerTest : public ::testing::Test {
dependency_["z4"] = {"bn"};
}
+ void CreateGrapplerItemWithSendRecv() {
+ const string gdef_ascii = R"EOF(
+node {
+ name: "Const"
+ op: "Const"
+ device: "/job:localhost/replica:0/task:0/device:CPU:0"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 3.1415
+ }
+ }
+ }
+}
+node {
+ name: "Send"
+ op: "_Send"
+ input: "Const"
+ device: "/job:localhost/replica:0/task:0/device:CPU:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "client_terminated"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "recv_device"
+ value {
+ s: "/job:localhost/replica:0/task:0/device:CPU:0"
+ }
+ }
+ attr {
+ key: "send_device"
+ value {
+ s: "/job:localhost/replica:0/task:0/device:CPU:0"
+ }
+ }
+ attr {
+ key: "send_device_incarnation"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "tensor_name"
+ value {
+ s: "test"
+ }
+ }
+}
+node {
+ name: "Recv"
+ op: "_Recv"
+ device: "/job:localhost/replica:0/task:0/device:CPU:0"
+ attr {
+ key: "client_terminated"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "recv_device"
+ value {
+ s: "/job:localhost/replica:0/task:0/device:CPU:0"
+ }
+ }
+ attr {
+ key: "send_device"
+ value {
+ s: "/job:localhost/replica:0/task:0/device:CPU:0"
+ }
+ }
+ attr {
+ key: "send_device_incarnation"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "tensor_name"
+ value {
+ s: "test"
+ }
+ }
+ attr {
+ key: "tensor_type"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+library {
+}
+versions {
+ producer: 24
+}
+ )EOF";
+
+ grappler_item_.reset(new GrapplerItem);
+ CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
+ &grappler_item_->graph));
+ grappler_item_->id = "test_graph";
+ grappler_item_->fetch = {"Recv"};
+ }
+
// A simple while loop
void CreateGrapplerItemWithLoop() {
// Test graph produced in python using:
@@ -743,6 +864,7 @@ versions {
do {
OpContext op_context = scheduler_->GetCurrNode();
ops_executed[op_context.name] = op_context;
+ std::cout << op_context.name << std::endl;
Costs node_costs = SimplePredictCosts(op_context);
@@ -1530,5 +1652,54 @@ TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
EXPECT_EQ(get_output_size(recv_op_names[-1]), 4);
EXPECT_EQ(get_output_size(send_op_names[-1]), 4);
}
+
+TEST_F(VirtualSchedulerTest, GraphWithSendRecv) {
+ // Init.
+ CreateGrapplerItemWithSendRecv();
+ InitScheduler();
+
+ // Run the scheduler.
+ auto ops_executed = RunScheduler("");
+
+ EXPECT_GT(ops_executed.count("Const"), 0);
+ EXPECT_GT(ops_executed.count("Send"), 0);
+ EXPECT_GT(ops_executed.count("Recv"), 0);
+}
+
+TEST_F(VirtualSchedulerTest, GraphWithSendRecvDifferentDevice) {
+ // Init.
+ CreateGrapplerItemWithSendRecv();
+ // Change Recv node's device so that Send and Recv are placed on different
+ // devices.
+ auto& graph = grappler_item_->graph;
+ const string recv_device = kCPU1;
+ for (int i = 0; i < graph.node_size(); i++) {
+ auto* node = graph.mutable_node(i);
+ if (node->name() == "Recv") {
+ node->set_device(recv_device);
+ auto* attr = node->mutable_attr();
+ (*attr)["recv_device"].set_s(recv_device);
+ } else if (node->name() == "Send") {
+ auto* attr = node->mutable_attr();
+ (*attr)["recv_device"].set_s(recv_device);
+ }
+ }
+ InitScheduler();
+
+ // Run the scheduler.
+ auto ops_executed = RunScheduler("");
+
+ // Expect Const, Send, Recv, and VirtualScheduler created Send and Recv ops.
+ EXPECT_GT(ops_executed.count("Const"), 0);
+ EXPECT_GT(ops_executed.count("Send"), 0);
+ EXPECT_GT(ops_executed.count("Send_Send_0_from_/job_localhost/replica_0/"
+ "task_0/cpu_0_to_/job_localhost"
+ "/replica_0/task_0/cpu_1"),
+ 0);
+ EXPECT_GT(ops_executed.count(
+ "Recv_Send_0_on_/job_localhost/replica_0/task_0/cpu_1"),
+ 0);
+ EXPECT_GT(ops_executed.count("Recv"), 0);
+}
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc
index 94412eb198..844a1fa328 100644
--- a/tensorflow/core/grappler/grappler_item.cc
+++ b/tensorflow/core/grappler/grappler_item.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
@@ -117,8 +118,13 @@ std::vector<const NodeDef*> ComputeTransitiveFanin(
bool* ill_formed) {
*ill_formed = false;
std::unordered_map<string, const NodeDef*> name_to_node;
+ std::unordered_map<string, const NodeDef*> name_to_send;
for (const auto& node : graph.node()) {
name_to_node[node.name()] = &node;
+ if (node.op() == "_Send") {
+ const auto& attr = node.attr();
+ name_to_send[attr.at("tensor_name").s()] = &node;
+ }
}
std::vector<const NodeDef*> queue;
@@ -150,6 +156,15 @@ std::vector<const NodeDef*> ComputeTransitiveFanin(
}
queue.push_back(in);
}
+ if (node->op() == "_Recv") {
+ const auto& attr = node->attr();
+ const NodeDef* send = name_to_send[attr.at("tensor_name").s()];
+ if (send) {
+ queue.push_back(send);
+ }
+ // Subgraph after partitioning may have either _Send or _Recv, not both.
+ // So, we do not set ill_formed for missing _Send.
+ }
}
return result;
}
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 669d02815c..54004a5e07 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -112,6 +112,7 @@ tf_cc_test(
deps = [
":constant_folding",
"//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:cc_ops_internal",
"//tensorflow/core:all_kernels",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 38af7170b5..c0518736fe 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -707,7 +707,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
node_map->AddOutput(new_transpose->name(), new_cast->name());
new_nodes->push_back(new_transpose);
- new_nodes->push_back(new_cast);
// Add frame dependencies that the original node might have had.
AddFrameControlDeps(node, {new_transpose, new_cast},
new_transpose->input(0), {new_transpose},
@@ -799,7 +798,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
scale_tensor.tensor_shape().dim_size() == 0) {
// Create new node `scaled_weights`.
NodeDef* scaled_weights = graph_def->add_node();
- scaled_weights->set_name(weights->name() + "_scaled");
+ scaled_weights->set_name(weights->name() + "_scaled_" +
+ conv->name());
scaled_weights->set_op("Mul");
scaled_weights->set_device(weights->device());
(*scaled_weights->mutable_attr())["T"] =
@@ -837,8 +837,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
}
}
- if (node->input_size() > 0 && IsAggregate(*node) &&
- !node_map->GetOutputs(node->name()).empty()) {
+ if (node->input_size() > 0 && IsAggregate(*node)) {
// Discard aggregate nodes with a single input.
if (node->input_size() == 1) {
return node->input(0);
@@ -859,7 +858,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
break;
}
}
- if (all_equal) {
+ if (all_equal && node_map->GetNode(node->name() + "_const") == nullptr) {
// 1. Create constant node with value N.
const int N = node->input_size();
const auto type = GetDataTypeFromAttr(*node, "T");
@@ -885,7 +884,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
new_mul_node->set_device(node->device());
SetDataTypeToAttr(type, "T", new_mul_node);
node_map->AddNode(new_mul_node->name(), new_mul_node);
- new_nodes->push_back(new_mul_node);
new_mul_node->add_input(new_const_node->name());
node_map->AddOutput(new_const_node->name(), new_mul_node->name());
new_mul_node->add_input(node->input(0));
@@ -902,7 +900,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
// where all the inputs are Mul nodes. This pattern occurs frequently in
// regularization terms for the gradients during training.
if (node->input_size() > 1 && IsAggregate(*node) &&
- !node_map->GetOutputs(node->name()).empty()) {
+ node_map->GetNode(node->name() + "_hoist") == nullptr) {
// Determine the set of common factors if the input nodes are all Mul nodes.
std::set<string> common_factors;
int i = 0;
@@ -950,7 +948,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
new_mul_node->set_name(new_mul_node->name() + "_hoist");
new_mul_node->set_input(0, common_factor);
new_mul_node->set_input(1, new_add_node->name());
- new_nodes->push_back(new_mul_node);
node_map->AddNode(new_mul_node->name(), new_mul_node);
}
}
@@ -1015,8 +1012,9 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
}
// Fold Conj into Transpose or ConjugateTranspose.
- if (node->op() == "Conj" || node->op() == "Transpose" ||
- node->op() == "ConjugateTranspose") {
+ if ((node->op() == "Conj" || node->op() == "Transpose" ||
+ node->op() == "ConjugateTranspose") &&
+ node_map->GetNode(node->name() + "_fused") == nullptr) {
const NodeDef* input = node_map->GetNode(node->input(0));
const NodeDef* transpose_op = node->op() == "Conj" ? input : node;
const NodeDef* conj_op = node->op() == "Conj" ? node : input;
@@ -1049,10 +1047,14 @@ namespace {
template <class T>
class SetVector {
public:
- void PushBack(const T& value) {
- CHECK(!Exists(value)) << "Value " << value << " is already in the set.";
- set_.insert(value);
+ // Returns false if value already existed in the set, true otherwise.
+ bool PushBack(const T& value) {
+ if (!set_.insert(value).second) {
+ VLOG(2) << "Value " << value << " is already in the set.";
+ return false;
+ }
vector_.push_back(value);
+ return true;
}
T PopBack() {
@@ -1093,6 +1095,11 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(
}
if (NodeName(simplified_tensor) != node->name()) {
+ // Always consider simplified_tensor for further optimizations.
+ const NodeDef* simplified_node = node_map.GetNode(simplified_tensor);
+ if (simplified_node != nullptr) {
+ nodes_to_simplify.PushBack(simplified_node);
+ }
// When `node` is simplifed to another node rather than in-place, the
// consumers of `node` are already redirected to `simplified_tensor`.
// Re-push the consumers into `nodes_to_simplify` for further
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 9f471302c7..4fcbb0120e 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -38,8 +38,8 @@ TEST_F(ArithmeticOptimizerTest, NoOp) {
ArithmeticOptimizer optimizer;
GraphDef output;
- Status s = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(s);
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
EXPECT_EQ(item.graph.node_size(), output.node_size());
for (int i = 0; i < item.graph.node_size(); ++i) {
@@ -66,6 +66,10 @@ TEST_F(ArithmeticOptimizerTest, OpDedupping) {
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
+ // Run the optimizer twice to make sure the rewrite is idempotent.
+ item.graph.Swap(&output);
+ status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
EXPECT_EQ(2, output.node_size());
const NodeDef& new_c1 = output.node(0);
@@ -91,6 +95,10 @@ TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
+ // Run the optimizer twice to make sure the rewrite is idempotent.
+ item.graph.Swap(&output);
+ status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
EXPECT_EQ(4, output.node_size());
const NodeDef& new_c1 = output.node(0);
@@ -146,13 +154,17 @@ TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithChain) {
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
+ // Run the optimizer twice to make sure the rewrite is idempotent.
+ item.graph.Swap(&output);
+ status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
EXPECT_EQ(6, output.node_size());
EXPECT_EQ("squeeze", output.node(5).input(0));
EXPECT_EQ("c", output.node(2).input(0));
}
-TEST_F(ArithmeticOptimizerTest, SimplifyReplaceTrivialSums) {
+TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
Output add = ops::Add(s.WithOpName("add"), x, x);
@@ -165,6 +177,10 @@ TEST_F(ArithmeticOptimizerTest, SimplifyReplaceTrivialSums) {
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
+ // Run the optimizer twice to make sure the rewrite is idempotent.
+ item.graph.Swap(&output);
+ status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
EXPECT_EQ(5, output.node_size());
const NodeDef& new_const = output.node(3);
@@ -178,7 +194,61 @@ TEST_F(ArithmeticOptimizerTest, SimplifyReplaceTrivialSums) {
EXPECT_EQ("add_mul", new_id.input(0));
}
-TEST_F(ArithmeticOptimizerTest, SimplifyHoistFactor) {
+TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
+ // Test case from b/69059093.
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output p = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({10, 10}));
+ Output add = ops::Add(s.WithOpName("Add"), p, p);
+ Output add1 = ops::Add(s.WithOpName("Add_1"), p, p);
+ Output add4 = ops::Add(s.WithOpName("Add_4"), add, add1);
+ Output add5 = ops::Add(s.WithOpName("Add_5"), add, add1);
+ Output add6 = ops::Add(s.WithOpName("Add_6"), add4, add5);
+ Output id = ops::Identity(s.WithOpName("id"), add6);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ ArithmeticOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ // Run the optimizer twice to make sure the rewrite is idempotent.
+ item.graph.Swap(&output);
+ status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ EXPECT_EQ(11, output.node_size());
+ const NodeDef& new_id = output.node(4);
+ EXPECT_EQ("id", new_id.name());
+ EXPECT_EQ("Add_6_mul", new_id.input(0));
+
+ // Add4 and add5 get deduped, and we rewrite each of the 3 remaining add nodes
+ // of the form Add(x,x) into Mul(Const(2), x).
+ const NodeDef& new_add_4_const = output.node(5);
+ EXPECT_EQ("Add_4_const", new_add_4_const.name());
+ EXPECT_EQ("^Add", new_add_4_const.input(0));
+ const NodeDef& new_add_4_mul = output.node(6);
+ EXPECT_EQ("Add_4_mul", new_add_4_mul.name());
+ EXPECT_EQ("Add_4_const", new_add_4_mul.input(0));
+ EXPECT_EQ("Add_mul", new_add_4_mul.input(1));
+
+ const NodeDef& new_add_6_const = output.node(7);
+ EXPECT_EQ("Add_6_const", new_add_6_const.name());
+ EXPECT_EQ("^Add_4_mul", new_add_6_const.input(0));
+ const NodeDef& new_add_6_mul = output.node(8);
+ EXPECT_EQ("Add_6_mul", new_add_6_mul.name());
+ EXPECT_EQ("Add_6_const", new_add_6_mul.input(0));
+ EXPECT_EQ("Add_4_mul", new_add_6_mul.input(1));
+
+ const NodeDef& new_add_const = output.node(9);
+ EXPECT_EQ("Add_const", new_add_const.name());
+ EXPECT_EQ("^Placeholder", new_add_const.input(0));
+ const NodeDef& new_add_mul = output.node(10);
+ EXPECT_EQ("Add_mul", new_add_mul.name());
+ EXPECT_EQ("Add_const", new_add_mul.input(0));
+ EXPECT_EQ("Placeholder", new_add_mul.input(1));
+}
+
+TEST_F(ArithmeticOptimizerTest, HoistFactor) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
Output y1 = ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
@@ -195,6 +265,10 @@ TEST_F(ArithmeticOptimizerTest, SimplifyHoistFactor) {
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
+ // Run the optimizer twice to make sure the rewrite is idempotent.
+ item.graph.Swap(&output);
+ status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
EXPECT_EQ(9, output.node_size());
const NodeDef& new_add = output.node(8);
@@ -225,6 +299,10 @@ TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) {
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
+ // Run the optimizer twice to make sure the rewrite is idempotent.
+ item.graph.Swap(&output);
+ status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
EXPECT_EQ(7, output.node_size());
EXPECT_EQ("trans_fused", output.node(6).name());
@@ -272,6 +350,10 @@ TEST_F(ArithmeticOptimizerTest, FuseTransposeAndConj) {
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
+ // Run the optimizer twice to make sure the rewrite is idempotent.
+ item.graph.Swap(&output);
+ status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
EXPECT_EQ(7, output.node_size());
EXPECT_EQ("conj_fused", output.node(6).name());
@@ -304,6 +386,10 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
+ // Run the optimizer twice to make sure the rewrite is idempotent.
+ item.graph.Swap(&output);
+ status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
EXPECT_EQ(7, output.node_size());
EXPECT_EQ("matmul_fused", output.node(6).name());
@@ -801,7 +887,7 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) {
CHECK_NOTNULL(node_map.GetNode("Transpose_uint8"));
const NodeDef* cast_node = CHECK_NOTNULL(node_map.GetNode("Cast_new"));
const NodeDef* weights_node =
- CHECK_NOTNULL(node_map.GetNode("weights_scaled"));
+ CHECK_NOTNULL(node_map.GetNode("weights_scaled_Conv2D"));
const NodeDef* conv_node = CHECK_NOTNULL(node_map.GetNode("Conv2D"));
EXPECT_EQ(output.node_size(), 7);
@@ -811,6 +897,50 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) {
EXPECT_EQ(conv_node->input(1), weights_node->name());
}
+TEST_F(ArithmeticOptimizerTest, OptimizeMultipleMulTransposeConv) {
+ // This unit test exercises optimization of folding mul into conv for
+ // multiple nodes in the graph.
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/gpu:0");
+
+ GrapplerItem item;
+ Output conv[2];
+
+ for (int i = 0; i < 2; ++i) {
+ Output inputs =
+ ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({8, 3, 28, 28}));
+ Output mul = ops::Mul(s, inputs, ops::Const(s, 1.0f / 255.0f));
+ Output weights = ops::Const(s.WithOpName("weights"),
+ Input::Initializer(127.0f, {5, 5, 3, 16}));
+ conv[i] = ops::Conv2D(s, mul, weights, {1, 1, 1, 1}, "VALID",
+ ops::Conv2D::DataFormat("NCHW"));
+ }
+ Output outputs = ops::Add(s.WithOpName("outputs"), conv[0], conv[1]);
+
+ item.fetch = {"outputs"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
+
+ item.graph = output;
+ TF_EXPECT_OK(
+ ConstantFolding(/*cpu_device=*/nullptr).Optimize(nullptr, item, &output));
+
+ item.graph = output;
+ TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+
+ NodeMap node_map(&output);
+ const NodeDef* weights_node =
+ CHECK_NOTNULL(node_map.GetNode("weights_scaled_Conv2D"));
+ const NodeDef* conv_node = CHECK_NOTNULL(node_map.GetNode("Conv2D"));
+
+ const NodeDef* weights_node_1 =
+ CHECK_NOTNULL(node_map.GetNode("weights_scaled_Conv2D_1"));
+ const NodeDef* conv_node_1 = CHECK_NOTNULL(node_map.GetNode("Conv2D_1"));
+ EXPECT_EQ(conv_node->input(1), weights_node->name());
+ EXPECT_EQ(conv_node_1->input(1), weights_node_1->name());
+}
+
TEST_F(ArithmeticOptimizerTest, CombineBitcasts) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs =
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index cb02314183..02a732b092 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/version.h"
+#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
namespace grappler {
@@ -95,11 +96,15 @@ class DeviceSimple : public DeviceBase {
};
} // namespace
-ConstantFolding::ConstantFolding(DeviceBase* cpu_device)
- : cpu_device_(cpu_device) {
+ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level,
+ DeviceBase* cpu_device)
+ : opt_level_(opt_level), cpu_device_(cpu_device) {
resource_mgr_.reset(new ResourceMgr());
}
+ConstantFolding::ConstantFolding(DeviceBase* cpu_device)
+ : ConstantFolding(RewriterConfig::ON, cpu_device) {}
+
// static
string ConstantFolding::AddControlDependency(const string& input_name,
GraphDef* graph,
@@ -281,6 +286,149 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item,
return Status::OK();
}
+bool ShapesEqual(const TensorShapeProto& shape1,
+ const TensorShapeProto& shape2) {
+ if (shape1.unknown_rank() || shape2.unknown_rank()) {
+ return false;
+ }
+ if (shape1.dim_size() != shape2.dim_size()) {
+ return false;
+ }
+ for (int i = 0; i < shape1.dim_size(); ++i) {
+ if (shape1.dim(i).size() != shape2.dim(i).size()) {
+ return false;
+ }
+ }
+ return true;
+}
+
+namespace {
+bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties,
+ BCast::Vec* shape, int64* min_id) {
+ if (shape_node.op() == "Shape") {
+ const std::vector<OpInfo::TensorProperties>& prop1 =
+ properties.GetInputProperties(shape_node.name());
+ if (prop1.size() != 1) {
+ return false;
+ }
+ const TensorShapeProto& shp = prop1[0].shape();
+ if (shp.unknown_rank()) {
+ return false;
+ }
+ for (const auto& dim : shp.dim()) {
+ shape->push_back(dim.size());
+ *min_id = std::min<int64>(*min_id, dim.size());
+ }
+ } else {
+ const TensorProto& raw_val = shape_node.attr().at("value").tensor();
+ if (raw_val.dtype() != DT_INT64 && raw_val.dtype() != DT_INT32) {
+ return false;
+ }
+ Tensor value(raw_val.dtype(), raw_val.tensor_shape());
+ if (!value.FromProto(raw_val)) {
+ return false;
+ }
+ for (int j = 0; j < value.NumElements(); ++j) {
+ if (raw_val.dtype() == DT_INT64) {
+ shape->push_back(value.vec<int64>()(j));
+ } else {
+ shape->push_back(value.vec<int>()(j));
+ }
+ }
+ }
+ return true;
+}
+} // namespace
+
+Status ConstantFolding::MaterializeConstants(
+ const GrapplerItem& item, const GraphProperties& properties) {
+ const int node_count = graph_.node_size();
+ for (int i = 0; i < node_count; ++i) {
+ NodeDef& node = *graph_.mutable_node(i);
+ const string& op = node.op();
+ if (op != "BroadcastGradientArgs") {
+ continue;
+ }
+ const NodeDef* shape_node1 = node_map_->GetNode(node.input(0));
+ const NodeDef* shape_node2 = node_map_->GetNode(node.input(1));
+ if (shape_node1 == nullptr ||
+ (shape_node1->op() != "Shape" && shape_node1->op() != "Const") ||
+ shape_node2 == nullptr ||
+ (shape_node2->op() != "Shape" && shape_node2->op() != "Const")) {
+ continue;
+ }
+ int64 min_id = 0;
+ BCast::Vec shape1;
+ if (!ExtractShape(*shape_node1, properties, &shape1, &min_id)) {
+ continue;
+ }
+ BCast::Vec shape2;
+ if (!ExtractShape(*shape_node2, properties, &shape2, &min_id)) {
+ continue;
+ }
+ // A value of -1 means we don't known anything about the dimension. Replace
+ // the -1 values with unique dimension ids since we don't want two '-1'
+ // dimensions to be considered equal.
+ for (auto& id : shape1) {
+ if (id == -1) {
+ id = --min_id;
+ }
+ }
+ for (auto& id : shape2) {
+ if (id == -1) {
+ id = --min_id;
+ }
+ }
+ BCast bcast(shape1, shape2);
+ if (!bcast.IsValid()) {
+ continue;
+ }
+ BCast::Vec reduce_dims[2];
+ reduce_dims[0] = bcast.grad_x_reduce_idx();
+ reduce_dims[1] = bcast.grad_y_reduce_idx();
+
+ const DataType type = node.attr().at("T").type();
+ NodeDef* out[2];
+ for (int j = 0; j < 2; ++j) {
+ if (!reduce_dims[j].empty()) {
+ // This is the case when a tensor dimension 1 is matched against an
+ // unknown dimension. The unknown dimension could also be equal to 1, in
+ // which case there would be no reduction.
+ out[j] = nullptr;
+ } else {
+ Tensor value(type, TensorShape({0}));
+ string const_name = AddPrefixToNodeName(
+ strings::StrCat(node.name(), "-", j), kConstantFoldingConst);
+ out[j] = node_map_->GetNode(const_name);
+ if (!out[j]) {
+ out[j] = graph_.add_node();
+ *out[j] = CreateNodeDef(const_name, TensorValue(&value));
+ out[j]->set_device(node.device());
+ node_map_->AddNode(const_name, out[j]);
+ string ctrl_dep =
+ AddControlDependency(node.name(), &graph_, node_map_.get());
+ *out[j]->add_input() = ctrl_dep;
+ node_map_->AddOutput(NodeName(ctrl_dep), const_name);
+ }
+ }
+ }
+
+ auto outputs = node_map_->GetOutputs(node.name());
+ for (const auto& output : outputs) {
+ for (int k = 0; k < output->input_size(); ++k) {
+ int port;
+ string node_name = ParseNodeName(output->input(k), &port);
+ if (node_name == node.name() && port >= 0 && port < 2 && out[port]) {
+ *output->mutable_input(k) = out[port]->name();
+ node_map_->UpdateInput(output->name(), node_name, out[port]->name());
+ }
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
bool ConstantFolding::IsFoldable(const NodeDef& node) const {
// Folding not applicable to ops with no inputs.
if (node.input().empty()) {
@@ -921,23 +1069,23 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
}
GraphProperties properties(item);
+ Status s = properties.InferStatically();
bool has_feed = !item.feed.empty();
- if (!has_feed) {
+
+ if (!has_feed && s.ok()) {
// Only use static shape information when there is no feed in the
// graph. That's because it's possible to feed a placeholder with a tensor
// of any shape, which could make the static information inconsistent with
// the shapes actually fed.
- Status s = properties.InferStatically();
- if (!s.ok()) {
- VLOG(1) << "Failed to infer graph shapes: " << s;
- } else {
- TF_RETURN_IF_ERROR(MaterializeShapes(item, properties));
- }
+ TF_RETURN_IF_ERROR(MaterializeShapes(item, properties));
+ }
+ if (opt_level_ == RewriterConfig::AGGRESSIVE && s.ok()) {
+ TF_RETURN_IF_ERROR(MaterializeConstants(item, properties));
}
TF_RETURN_IF_ERROR(FoldGraph(output));
- if (!has_feed) {
+ if (!has_feed && s.ok()) {
TF_RETURN_IF_ERROR(SimplifyGraph(output, properties));
}
return Status::OK();
@@ -956,12 +1104,14 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
GrapplerItem item_to_optimize = item;
*output = item.graph;
+ int64 node_count;
do {
graph_.Swap(output);
item_to_optimize.graph = graph_;
*output = GraphDef();
+ node_count = graph_.node_size();
TF_RETURN_IF_ERROR(RunOptimizationPass(cluster, item_to_optimize, output));
- } while (output->node_size() < graph_.node_size());
+ } while (output->node_size() != node_count);
*output->mutable_library() = item.graph.library();
*output->mutable_versions() = item.graph.versions();
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index 30d778789a..dd988f336c 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
namespace grappler {
@@ -37,6 +38,7 @@ class ConstantFolding : public GraphOptimizer {
NodeMap* node_map);
ConstantFolding(DeviceBase* cpu_device);
+ ConstantFolding(RewriterConfig::Toggle opt_level, DeviceBase* cpu_device);
~ConstantFolding() override {}
@@ -51,7 +53,8 @@ class ConstantFolding : public GraphOptimizer {
private:
Status MaterializeShapes(const GrapplerItem& item,
const GraphProperties& properties);
-
+ Status MaterializeConstants(const GrapplerItem& item,
+ const GraphProperties& properties);
bool IsFoldable(const NodeDef& node) const;
Status EvaluateNode(const NodeDef& node,
@@ -74,6 +77,7 @@ class ConstantFolding : public GraphOptimizer {
GraphDef* output);
// Points to an externally provided device or to owned_device_;
+ RewriterConfig::Toggle opt_level_;
DeviceBase* cpu_device_;
std::unique_ptr<DeviceBase> owned_device_;
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index a1dee6d2fb..43f84b1ddf 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
+#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
@@ -838,6 +839,85 @@ TEST_F(ConstantFoldingTest, Packing) {
// size needed to naively encode 1000 floats folded twice).
EXPECT_GT(8000, output.ByteSizeLong());
}
+
+TEST_F(ConstantFoldingTest, ConstantMaterialization) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a =
+ ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
+ ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
+ Output b = ops::Square(s.WithOpName("b"), a);
+ Output c = ops::Mul(s.WithOpName("c"), a, b);
+ Output d = ops::Shape(s.WithOpName("d"), a);
+ Output e = ops::Shape(s.WithOpName("e"), b);
+
+ auto f = ops::internal::BroadcastGradientArgs(s.WithOpName("f"), d, e);
+ Output o1 = ops::Identity(s.WithOpName("o1"), f.r0);
+ Output o2 = ops::Identity(s.WithOpName("o2"), f.r1);
+
+ Output g = ops::Placeholder(s.WithOpName("g"), DT_FLOAT,
+ ops::Placeholder::Shape(PartialTensorShape({1})));
+ Output h = ops::Shape(s.WithOpName("h"), g);
+ auto i = ops::internal::BroadcastGradientArgs(s.WithOpName("i"), d, h);
+ Output p1 = ops::Identity(s.WithOpName("p1"), i.r0);
+ Output p2 = ops::Identity(s.WithOpName("p2"), i.r1);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ ConstantFolding fold(RewriterConfig::AGGRESSIVE, nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = fold.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ // Run a second time to make sure the optimization is idempotent.
+ item.graph.Swap(&output);
+ status = fold.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ int found = 0;
+ for (const auto& node : output.node()) {
+ if (node.name() == "o1") {
+ ++found;
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("ConstantFolding/f-0", node.input(0));
+ } else if (node.name() == "o2") {
+ ++found;
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("ConstantFolding/f-1", node.input(0));
+ } else if (node.name() == "ConstantFolding/f-0") {
+ ++found;
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("^f", node.input(0));
+ EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
+ .num_elements());
+ } else if (node.name() == "ConstantFolding/f-1") {
+ ++found;
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("^f", node.input(0));
+ EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
+ .num_elements());
+ } else if (node.name() == "p1") {
+ ++found;
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("ConstantFolding/i-0", node.input(0));
+ } else if (node.name() == "p2") {
+ ++found;
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("i:1", node.input(0));
+ } else if (node.name() == "ConstantFolding/i-0") {
+ ++found;
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("^i", node.input(0));
+ EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
+ .num_elements());
+ }
+ }
+ EXPECT_EQ(7, found);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index a9875c06d8..6204a81f80 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -64,8 +64,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
optimizers.push_back(std::unique_ptr<GraphOptimizer>(new ModelPruner()));
}
if (cfg_.constant_folding() != RewriterConfig::OFF) {
- optimizers.push_back(
- std::unique_ptr<GraphOptimizer>(new ConstantFolding(cpu_device_)));
+ optimizers.push_back(std::unique_ptr<GraphOptimizer>(
+ new ConstantFolding(cfg_.constant_folding(), cpu_device_)));
}
if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
optimizers.push_back(std::unique_ptr<GraphOptimizer>(
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index 54be02b5f8..3a5028cfe3 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -45,7 +45,6 @@ NodeDef* NodeMap::GetNode(const string& name) const {
string node_name = NodeName(name);
auto it = nodes_.find(node_name);
if (it == nodes_.end()) {
- LOG(WARNING) << "Node " << node_name << " is not in the graph.";
return nullptr;
}
return it->second;
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index a5c62fef17..f27b00c1b1 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -4443,6 +4443,15 @@ filegroup(
"fill_functor.h",
"function_ops.cc",
"gather_functor.h",
+ "gather_nd_op.cc",
+ "gather_nd_op.h",
+ "gather_nd_op_cpu_impl.h",
+ "gather_nd_op_cpu_impl_0.cc",
+ "gather_nd_op_cpu_impl_1.cc",
+ "gather_nd_op_cpu_impl_2.cc",
+ "gather_nd_op_cpu_impl_3.cc",
+ "gather_nd_op_cpu_impl_4.cc",
+ "gather_nd_op_cpu_impl_5.cc",
"gather_op.cc",
"identity_n_op.cc",
"identity_n_op.h",
@@ -4536,6 +4545,10 @@ filegroup(
"fused_batch_norm_op.h",
"gemm_functors.h",
"image_resizer_state.h",
+ "initializable_lookup_table.h",
+ "lookup_table_init_op.h",
+ "lookup_table_op.h",
+ "lookup_util.h",
"maxpooling_op.h",
"mfcc.h",
"mfcc_dct.h",
@@ -4552,6 +4565,7 @@ filegroup(
"resize_nearest_neighbor_op.h",
"reverse_op.h",
"save_restore_tensor.h",
+ "segment_reduction_ops.h",
"softplus_op.h",
"softsign_op.h",
"spacetobatch_functor.h",
@@ -4601,6 +4615,8 @@ filegroup(
"cwise_op_div.cc",
"cwise_op_equal_to_1.cc",
"cwise_op_equal_to_2.cc",
+ "cwise_op_not_equal_to_1.cc",
+ "cwise_op_not_equal_to_2.cc",
"cwise_op_exp.cc",
"cwise_op_floor.cc",
"cwise_op_floor_div.cc",
@@ -4642,6 +4658,7 @@ filegroup(
"encode_wav_op.cc",
"fake_quant_ops.cc",
"fifo_queue.cc",
+ "fifo_queue_op.cc",
"fused_batch_norm_op.cc",
"population_count_op.cc",
"population_count_op.h",
@@ -4665,7 +4682,11 @@ filegroup(
"depthtospace_op.cc",
"dynamic_stitch_op.cc",
"in_topk_op.cc",
+ "initializable_lookup_table.cc",
"logging_ops.cc",
+ "lookup_table_init_op.cc",
+ "lookup_table_op.cc",
+ "lookup_util.cc",
"lrn_op.cc",
"maxpooling_op.cc",
"mfcc.cc",
@@ -4700,12 +4721,15 @@ filegroup(
"save_op.cc",
"save_restore_tensor.cc",
"save_restore_v2_ops.cc",
+ "segment_reduction_ops.cc",
"session_ops.cc",
"softplus_op.cc",
"softsign_op.cc",
"spacetobatch_functor.cc",
"spacetobatch_op.cc",
"spacetodepth_op.cc",
+ "sparse_fill_empty_rows_op.cc",
+ "sparse_reshape_op.cc",
"sparse_to_dense_op.cc",
"spectrogram.cc",
"spectrogram_op.cc",
@@ -4728,6 +4752,7 @@ filegroup(
"training_ops.cc",
"transpose_functor_cpu.cc",
"transpose_op.cc",
+ "unique_op.cc",
"warn_about_ints.cc",
"where_op.cc",
"xent_op.cc",
@@ -6241,8 +6266,11 @@ tf_kernel_library(
srcs = ["summary_kernels.cc"],
deps = [
":summary_interface",
+ "//tensorflow/contrib/tensorboard/db:summary_db_writer",
"//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//tensorflow/core:summary_ops_op_lib",
+ "//tensorflow/core/lib/db:sqlite",
],
)
diff --git a/tensorflow/core/kernels/batch_dataset_op.cc b/tensorflow/core/kernels/batch_dataset_op.cc
index 2e52ad39f8..6a5fd17a9e 100644
--- a/tensorflow/core/kernels/batch_dataset_op.cc
+++ b/tensorflow/core/kernels/batch_dataset_op.cc
@@ -143,9 +143,13 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
// Each row of `batch_elements` is a tuple of tensors from the
// input iterator.
std::vector<std::vector<Tensor>> batch_elements;
- batch_elements.reserve(dataset()->batch_size_);
{
mutex_lock l(mu_);
+ if (!input_impl_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ batch_elements.reserve(dataset()->batch_size_);
*end_of_sequence = false;
for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence;
++i) {
@@ -154,6 +158,8 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
end_of_sequence));
if (!*end_of_sequence) {
batch_elements.emplace_back(std::move(batch_element_tuple));
+ } else {
+ input_impl_.reset();
}
}
}
@@ -194,14 +200,23 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+ if (!input_impl_) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("input_impl_empty"), ""));
+ } else {
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+ }
return Status::OK();
}
Status RestoreInternal(OpKernelContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ if (!reader->Contains(full_name("input_impl_empty"))) {
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ } else {
+ input_impl_.reset();
+ }
return Status::OK();
}
diff --git a/tensorflow/core/kernels/concat_lib_cpu.cc b/tensorflow/core/kernels/concat_lib_cpu.cc
index 258ce15456..b0bec0c5dc 100644
--- a/tensorflow/core/kernels/concat_lib_cpu.cc
+++ b/tensorflow/core/kernels/concat_lib_cpu.cc
@@ -74,11 +74,14 @@ REGISTER(qint16)
REGISTER(qint32)
REGISTER(bfloat16)
-#if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION)
-// Primarily used for SavedModel support on mobile.
+#if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) && \
+ !defined(__ANDROID_TYPES_FULL__)
+// Primarily used for SavedModel support on mobile. Registering it here only if
+// __ANDROID_TYPES_FULL__ is not defined, as that already register strings
REGISTER(string);
#endif // defined(IS_MOBILE_PLATFORM) &&
- // !defined(SUPPORT_SELECTIVE_REGISTRATION)
+ // !defined(SUPPORT_SELECTIVE_REGISTRATION) &&
+ // !defined(__ANDROID_TYPES_FULL__)
#ifdef TENSORFLOW_USE_SYCL
template <typename T>
diff --git a/tensorflow/core/kernels/concatenate_dataset_op.cc b/tensorflow/core/kernels/concatenate_dataset_op.cc
index 711c234129..c3bd89c479 100644
--- a/tensorflow/core/kernels/concatenate_dataset_op.cc
+++ b/tensorflow/core/kernels/concatenate_dataset_op.cc
@@ -104,6 +104,10 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
+ if (!input_impl_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
while (i_ < 2) {
TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
@@ -140,7 +144,9 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
} else if (i_ == 2) {
input_impl_.reset();
}
- TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ if (input_impl_) {
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ }
return Status::OK();
}
diff --git a/tensorflow/core/kernels/dataset.cc b/tensorflow/core/kernels/dataset.cc
index 0414875a5d..fcfa2956f7 100644
--- a/tensorflow/core/kernels/dataset.cc
+++ b/tensorflow/core/kernels/dataset.cc
@@ -126,7 +126,6 @@ void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
MakeDataset(ctx, input, another_input, output);
}
-const char IteratorBase::kIteratorExhausted[] = "ITERATOR_EXHAUSTED";
const char GraphDatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH";
const char GraphDatasetBase::kDatasetGraphOutputNodeKey[] =
"_DATASET_GRAPH_OUTPUT_NODE";
diff --git a/tensorflow/core/kernels/dataset.h b/tensorflow/core/kernels/dataset.h
index 4a42ac80c3..aa4f436b39 100644
--- a/tensorflow/core/kernels/dataset.h
+++ b/tensorflow/core/kernels/dataset.h
@@ -306,27 +306,14 @@ class IteratorBase {
// Saves the state of this iterator.
virtual Status Save(IteratorStateWriter* writer) {
- if (is_exhausted_) {
- LOG(INFO) << "Iterator exhausted.";
- return writer->WriteScalar(kIteratorExhausted, kIteratorExhausted);
- } else {
- return SaveInternal(writer);
- }
+ return SaveInternal(writer);
}
// Restores the state of this iterator.
virtual Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) {
- if (reader->Contains(kIteratorExhausted)) {
- LOG(INFO) << "Iterator exhausted. Nothing to restore.";
- is_exhausted_ = true;
- return Status::OK();
- } else {
- return RestoreInternal(ctx, reader);
- }
+ return RestoreInternal(ctx, reader);
}
- static const char kIteratorExhausted[];
-
protected:
// This is needed so that sub-classes of IteratorBase can call
// `SaveInternal` on their parent iterators, e.g., in
@@ -354,8 +341,6 @@ class IteratorBase {
IteratorStateReader* reader) {
return errors::Unimplemented("RestoreInternal");
}
-
- bool is_exhausted_ = false; // Whether the iterator has been exhausted.
};
// Represents a (potentially infinite) range of outputs, where each
@@ -491,10 +476,6 @@ class DatasetIterator : public IteratorBase {
Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) final {
port::Tracing::TraceMe activity(params_.prefix);
- if (is_exhausted_) {
- *end_of_sequence = true;
- return Status::OK();
- }
return GetNextInternal(ctx, out_tensors, end_of_sequence);
}
diff --git a/tensorflow/core/kernels/fake_quant_ops_functor.h b/tensorflow/core/kernels/fake_quant_ops_functor.h
index b41b22d634..7aaad6e6c7 100644
--- a/tensorflow/core/kernels/fake_quant_ops_functor.h
+++ b/tensorflow/core/kernels/fake_quant_ops_functor.h
@@ -132,7 +132,7 @@ struct FakeQuantWithMinMaxVarsFunctor {
const float max_val = max();
// If min and max are both zero, we should just return zero.
if (min_val == 0.0f && max_val == 0.0f) {
- outputs.setZero();
+ outputs.device(d) = outputs.constant(0.0f);
return;
}
float nudged_min, nudged_max, nudged_scale;
@@ -163,8 +163,8 @@ struct FakeQuantWithMinMaxVarsGradientFunctor {
// If min and max are both zero, we propagate everything to inputs.
if (min_val == 0.0f && max_val == 0.0f) {
backprops_wrt_input.device(d) = gradients;
- backprop_wrt_min.setZero();
- backprop_wrt_max.setZero();
+ backprop_wrt_min.device(d) = backprop_wrt_min.constant(0.0f);
+ backprop_wrt_max.device(d) = backprop_wrt_max.constant(0.0f);
return;
}
float nudged_min, nudged_max, nudged_scale;
@@ -205,7 +205,8 @@ struct FakeQuantWithMinMaxVarsPerChannelFunctor {
const float max_val = max(i);
// If min and max are both zero, we should just return zero.
if (min_val == 0.0f && max_val == 0.0f) {
- outputs.chip<1>(i).setZero();
+ auto chip = outputs.chip<1>(i);
+ chip.device(d) = chip.constant(0.0f);
continue;
}
float nudged_min, nudged_max, nudged_scale;
@@ -242,8 +243,10 @@ struct FakeQuantWithMinMaxVarsPerChannelGradientFunctor {
// If min and max are both zero, we propagate everything to inputs.
if (min_val == 0.0f && max_val == 0.0f) {
backprops_wrt_input.chip<1>(i).device(d) = gradients_chip;
- backprop_wrt_min.chip<0>(i).setZero();
- backprop_wrt_max.chip<0>(i).setZero();
+ auto min_chip = backprop_wrt_min.chip<0>(i);
+ auto max_chip = backprop_wrt_max.chip<0>(i);
+ min_chip.device(d) = min_chip.constant(0.0f);
+ max_chip.device(d) = max_chip.constant(0.0f);
continue;
}
float nudged_min, nudged_max, nudged_scale;
diff --git a/tensorflow/core/kernels/immutable_constant_op_test.cc b/tensorflow/core/kernels/immutable_constant_op_test.cc
index b318c9c79a..b3814331ee 100644
--- a/tensorflow/core/kernels/immutable_constant_op_test.cc
+++ b/tensorflow/core/kernels/immutable_constant_op_test.cc
@@ -147,8 +147,8 @@ Status CreateTempFile(Env* env, float value, uint64 size, string* filename) {
std::unique_ptr<WritableFile> file;
TF_RETURN_IF_ERROR(env->NewWritableFile(*filename, &file));
for (uint64 i = 0; i < size; ++i) {
- StringPiece sp;
- sp.set(&value, sizeof(value));
+ StringPiece sp(static_cast<char*>(static_cast<void*>(&value)),
+ sizeof(value));
TF_RETURN_IF_ERROR(file->Append(sp));
}
TF_RETURN_IF_ERROR(file->Close());
diff --git a/tensorflow/core/kernels/range_dataset_op.cc b/tensorflow/core/kernels/range_dataset_op.cc
index 7adfcc4f8d..e7ae840fc7 100644
--- a/tensorflow/core/kernels/range_dataset_op.cc
+++ b/tensorflow/core/kernels/range_dataset_op.cc
@@ -99,7 +99,6 @@ class RangeDatasetOp : public DatasetOpKernel {
if ((dataset()->step_ > 0 && next_ >= dataset()->stop_) ||
(dataset()->step_ < 0 && next_ <= dataset()->stop_)) {
*end_of_sequence = true;
- is_exhausted_ = true;
return Status::OK();
}
Tensor value_tensor(cpu_allocator(), DT_INT64, {});
diff --git a/tensorflow/core/kernels/reader_dataset_ops.cc b/tensorflow/core/kernels/reader_dataset_ops.cc
index 39ef92a5de..c08e42be1d 100644
--- a/tensorflow/core/kernels/reader_dataset_ops.cc
+++ b/tensorflow/core/kernels/reader_dataset_ops.cc
@@ -402,7 +402,6 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel {
// Iteration ends when there are no more files to process.
if (current_file_index_ == dataset()->filenames_.size()) {
*end_of_sequence = true;
- is_exhausted_ = true;
return Status::OK();
}
diff --git a/tensorflow/core/kernels/repeat_dataset_op.cc b/tensorflow/core/kernels/repeat_dataset_op.cc
index 9813e99a70..0167b9ea64 100644
--- a/tensorflow/core/kernels/repeat_dataset_op.cc
+++ b/tensorflow/core/kernels/repeat_dataset_op.cc
@@ -95,6 +95,15 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
*end_of_sequence = true;
return Status::OK();
}
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ return Status::OK();
+ }
+ Status RestoreInternal(OpKernelContext* ctx,
+ IteratorStateReader* reader) override {
+ return Status::OK();
+ }
};
class FiniteIterator : public DatasetIterator<Dataset> {
@@ -108,6 +117,10 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
+ if (!input_impl_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
while (i_ < dataset()->count_) {
TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
@@ -118,7 +131,6 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
input_impl_ = dataset()->input_->MakeIterator(prefix());
}
*end_of_sequence = true;
- is_exhausted_ = true;
input_impl_.reset();
return Status::OK();
}
@@ -127,7 +139,12 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
- TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+ if (!input_impl_) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("input_impl_empty"), ""));
+ } else {
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+ }
return Status::OK();
}
@@ -135,7 +152,11 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
- TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ if (!reader->Contains(full_name("input_impl_empty"))) {
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ } else {
+ input_impl_.reset();
+ }
return Status::OK();
}
@@ -183,6 +204,29 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
} while (true);
}
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ if (input_impl_)
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+ else
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("uninitialized"), ""));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(OpKernelContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ if (reader->Contains(full_name("uninitialized"))) {
+ input_impl_.reset();
+ } else {
+ input_impl_ = dataset()->input_->MakeIterator(prefix());
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ }
+ return Status::OK();
+ }
+
private:
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/kernels/shuffle_dataset_op.cc b/tensorflow/core/kernels/shuffle_dataset_op.cc
index 2146ba2aa1..dd0ab57e9d 100644
--- a/tensorflow/core/kernels/shuffle_dataset_op.cc
+++ b/tensorflow/core/kernels/shuffle_dataset_op.cc
@@ -105,8 +105,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
mutex_lock l(mu_);
int64 start_micros = ctx->env()->NowMicros();
int64 num_log_entries = 0;
- while (!end_of_input_sequence_ &&
- buffer_.size() < dataset()->buffer_size_) {
+ while (input_impl_ && buffer_.size() < dataset()->buffer_size_) {
if (ctx->env()->NowMicros() >
((num_log_entries + 1) * kLogIntervalMicros) + start_micros) {
num_log_entries++;
@@ -114,9 +113,10 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
<< buffer_.size() << " of " << dataset()->buffer_size_;
}
std::vector<Tensor> input_element;
+ bool end_of_input_sequence;
TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element,
- &end_of_input_sequence_));
- if (!end_of_input_sequence_) {
+ &end_of_input_sequence));
+ if (!end_of_input_sequence) {
buffer_.emplace_back(std::move(input_element));
} else {
input_impl_.reset();
@@ -135,7 +135,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
std::swap(buffer_[index], buffer_.back());
buffer_.pop_back();
} else {
- DCHECK(end_of_input_sequence_);
+ DCHECK(input_impl_ == nullptr);
*end_of_sequence = true;
}
return Status::OK();
@@ -148,11 +148,11 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
// Save the tensors in the buffer.
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("buffer_size"), buffer_.size()));
- for (int i = 0; i < buffer_.size(); i++) {
+ for (size_t i = 0; i < buffer_.size(); i++) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("buffer_", i, "_size")),
buffer_[i].size()));
- for (int j = 0; j < buffer_[i].size(); j++) {
+ for (size_t j = 0; j < buffer_[i].size(); j++) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(strings::StrCat("buffer_", i, "_", j)),
buffer_[i][j]));
@@ -165,7 +165,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
// Save input iterator if it hasn't been exhausted else write
// "end_of_input_sequence".
- if (end_of_input_sequence_) {
+ if (!input_impl_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("end_of_input_sequence"), ""));
} else {
@@ -180,10 +180,15 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
buffer_.clear();
// Restore the buffer.
- int64 buffer_size;
- TF_RETURN_IF_ERROR(
- reader->ReadScalar(full_name("buffer_size"), &buffer_size));
- for (int i = 0; i < buffer_size; i++) {
+ size_t buffer_size;
+ {
+ int64 temp;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("buffer_size"), &temp));
+ buffer_size = static_cast<size_t>(temp);
+ }
+ buffer_.reserve(buffer_size);
+ for (size_t i = 0; i < buffer_size; i++) {
int64 list_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat("buffer_", i, "_size")), &list_size));
@@ -205,7 +210,6 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
input_impl_ = dataset()->input_->MakeIterator(prefix());
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
} else {
- end_of_input_sequence_ = true;
input_impl_.reset();
}
return Status::OK();
@@ -230,7 +234,6 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
mutex mu_;
std::vector<std::vector<Tensor>> buffer_ GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
- bool end_of_input_sequence_ GUARDED_BY(mu_) = false;
const int64 seed_ GUARDED_BY(mu_);
const int64 seed2_ GUARDED_BY(mu_);
random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/kernels/skip_dataset_op.cc b/tensorflow/core/kernels/skip_dataset_op.cc
index 52a6116a7c..7ee945dd4c 100644
--- a/tensorflow/core/kernels/skip_dataset_op.cc
+++ b/tensorflow/core/kernels/skip_dataset_op.cc
@@ -35,14 +35,14 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
int64 count;
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
- *output = new Dataset(count, input);
+ *output = new Dataset(ctx, count, input);
}
private:
- class Dataset : public DatasetBase {
+ class Dataset : public GraphDatasetBase {
public:
- Dataset(int64 count, const DatasetBase* input)
- : count_(count), input_(input) {
+ Dataset(OpKernelContext* ctx, int64 count, const DatasetBase* input)
+ : GraphDatasetBase(ctx), count_(count), input_(input) {
input_->Ref();
}
@@ -71,6 +71,18 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
string DebugString() override { return "SkipDatasetOp::Dataset"; }
+ protected:
+ Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddParentDataset(input_, &input_graph_node));
+ Node* count = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
+ TF_RETURN_IF_ERROR(
+ b->AddDataset(this, {input_graph_node, count}, output));
+ return Status::OK();
+ }
+
private:
class EmptyIterator : public DatasetIterator<Dataset> {
public:
@@ -82,6 +94,16 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
*end_of_sequence = true;
return Status::OK();
}
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ return Status::OK();
+ }
+
+ Status RestoreInternal(OpKernelContext* ctx,
+ IteratorStateReader* reader) override {
+ return Status::OK();
+ }
};
class FiniteIterator : public DatasetIterator<Dataset> {
@@ -96,6 +118,11 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
bool* end_of_sequence) override {
mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
+ if (!input_impl_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
// Keep calling GetNext(). TODO(vrv): Figure out a way to
// skip records without reading, perhaps by adding an
// interface to iterator.
@@ -116,6 +143,34 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
// Return GetNext() on the underlying iterator.
TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, out_tensors,
end_of_sequence));
+ if (*end_of_sequence) {
+ input_impl_.reset();
+ }
+ return Status::OK();
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
+ if (input_impl_) {
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+ } else {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("input_impl_empty"), ""));
+ }
+ return Status::OK();
+ }
+
+ Status RestoreInternal(OpKernelContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
+ if (!reader->Contains(full_name("input_impl_empty"))) {
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ } else {
+ input_impl_.reset();
+ }
return Status::OK();
}
diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc
index 8fc40db3cc..73b6d4cf6a 100644
--- a/tensorflow/core/kernels/strided_slice_op.cc
+++ b/tensorflow/core/kernels/strided_slice_op.cc
@@ -427,6 +427,7 @@ REGISTER_STRIDED_SLICE(bfloat16);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_int64(REGISTER_GPU);
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
diff --git a/tensorflow/core/kernels/strided_slice_op_impl.h b/tensorflow/core/kernels/strided_slice_op_impl.h
index 7d42887426..a39fdff954 100644
--- a/tensorflow/core/kernels/strided_slice_op_impl.h
+++ b/tensorflow/core/kernels/strided_slice_op_impl.h
@@ -283,6 +283,7 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_N_GPU);
TF_CALL_complex64(DECLARE_FOR_N_GPU);
TF_CALL_complex128(DECLARE_FOR_N_GPU);
DECLARE_FOR_N_GPU(int32);
+DECLARE_FOR_N_GPU(int64);
#endif // END GOOGLE_CUDA
TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU);
@@ -298,6 +299,7 @@ DECLARE_FOR_N_CPU(bfloat16);
TF_CALL_SYCL_PROXY_TYPES(PREVENT_FOR_N_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_FOR_N_SYCL);
DECLARE_FOR_N_SYCL(int32);
+DECLARE_FOR_N_SYCL(int64);
#undef DECLARE_FOR_N_SYCL
#endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/summary_interface.cc b/tensorflow/core/kernels/summary_interface.cc
index 313137ae49..cd366f8c13 100644
--- a/tensorflow/core/kernels/summary_interface.cc
+++ b/tensorflow/core/kernels/summary_interface.cc
@@ -257,7 +257,9 @@ class SummaryWriterImpl : public SummaryWriterInterface {
Summary::Value* v = e->mutable_summary()->add_value();
t.AsProtoTensorContent(v->mutable_tensor());
v->set_tag(tag);
- v->mutable_metadata()->ParseFromString(serialized_metadata);
+ if (!serialized_metadata.empty()) {
+ v->mutable_metadata()->ParseFromString(serialized_metadata);
+ }
return WriteEvent(std::move(e));
}
diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc
index cfa707de71..1fe2fc5b66 100644
--- a/tensorflow/core/kernels/summary_kernels.cc
+++ b/tensorflow/core/kernels/summary_kernels.cc
@@ -13,9 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/kernels/summary_interface.h"
+#include "tensorflow/core/lib/db/sqlite.h"
+#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
@@ -46,6 +49,32 @@ class CreateSummaryFileWriterOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("CreateSummaryFileWriter").Device(DEVICE_CPU),
CreateSummaryFileWriterOp);
+class CreateSummaryDbWriterOp : public OpKernel {
+ public:
+ explicit CreateSummaryDbWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* tmp;
+ OP_REQUIRES_OK(ctx, ctx->input("db_uri", &tmp));
+ const string db_uri = tmp->scalar<string>()();
+ OP_REQUIRES_OK(ctx, ctx->input("experiment_name", &tmp));
+ const string experiment_name = tmp->scalar<string>()();
+ OP_REQUIRES_OK(ctx, ctx->input("run_name", &tmp));
+ const string run_name = tmp->scalar<string>()();
+ OP_REQUIRES_OK(ctx, ctx->input("user_name", &tmp));
+ const string user_name = tmp->scalar<string>()();
+ SummaryWriterInterface* s;
+ auto db = Sqlite::Open(db_uri);
+ OP_REQUIRES_OK(ctx, db.status());
+ OP_REQUIRES_OK(
+ ctx, CreateSummaryDbWriter(std::move(db.ValueOrDie()), experiment_name,
+ run_name, user_name, ctx->env(), &s));
+ OP_REQUIRES_OK(ctx, CreateResource(ctx, HandleFromInput(ctx, 0), s));
+ }
+};
+REGISTER_KERNEL_BUILDER(Name("CreateSummaryDbWriter").Device(DEVICE_CPU),
+ CreateSummaryDbWriterOp);
+
class FlushSummaryWriterOp : public OpKernel {
public:
explicit FlushSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
@@ -98,6 +127,27 @@ class WriteSummaryOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("WriteSummary").Device(DEVICE_CPU),
WriteSummaryOp);
+class ImportEventOp : public OpKernel {
+ public:
+ explicit ImportEventOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ SummaryWriterInterface* s;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
+ core::ScopedUnref unref(s);
+ const Tensor* t;
+ OP_REQUIRES_OK(ctx, ctx->input("event", &t));
+ std::unique_ptr<Event> event{new Event};
+ if (!ParseProtoUnlimited(event.get(), t->scalar<string>()())) {
+ ctx->CtxFailureWithWarning(
+ errors::DataLoss("Bad tf.Event binary proto tensor string"));
+ return;
+ }
+ OP_REQUIRES_OK(ctx, s->WriteEvent(std::move(event)));
+ }
+};
+REGISTER_KERNEL_BUILDER(Name("ImportEvent").Device(DEVICE_CPU), ImportEventOp);
+
class WriteScalarSummaryOp : public OpKernel {
public:
explicit WriteScalarSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
diff --git a/tensorflow/core/kernels/take_dataset_op.cc b/tensorflow/core/kernels/take_dataset_op.cc
index c3f33d663c..fb294a96b1 100644
--- a/tensorflow/core/kernels/take_dataset_op.cc
+++ b/tensorflow/core/kernels/take_dataset_op.cc
@@ -35,14 +35,14 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
// Create a new TakeDatasetOp::Dataset, and return it as the output.
int64 count;
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
- *output = new Dataset(count, input);
+ *output = new Dataset(ctx, count, input);
}
private:
- class Dataset : public DatasetBase {
+ class Dataset : public GraphDatasetBase {
public:
- Dataset(int64 count, const DatasetBase* input)
- : count_(count), input_(input) {
+ Dataset(OpKernelContext* ctx, int64 count, const DatasetBase* input)
+ : GraphDatasetBase(ctx), count_(count), input_(input) {
input_->Ref();
}
@@ -72,6 +72,18 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
string DebugString() override { return "TakeDatasetOp::Dataset"; }
+ protected:
+ Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddParentDataset(input_, &input_graph_node));
+ Node* count = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
+ TF_RETURN_IF_ERROR(
+ b->AddDataset(this, {input_graph_node, count}, output));
+ return Status::OK();
+ }
+
private:
class EmptyIterator : public DatasetIterator<Dataset> {
public:
@@ -83,6 +95,16 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
*end_of_sequence = true;
return Status::OK();
}
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ return Status::OK();
+ }
+
+ Status RestoreInternal(OpKernelContext* ctx,
+ IteratorStateReader* reader) override {
+ return Status::OK();
+ }
};
class FiniteIterator : public DatasetIterator<Dataset> {
@@ -96,6 +118,10 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
+ if (!input_impl_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
while (i_ < dataset()->count_) {
TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
@@ -110,6 +136,31 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
+ if (input_impl_) {
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+ } else {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("input_impl_empty"), ""));
+ }
+ return Status::OK();
+ }
+
+ Status RestoreInternal(OpKernelContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
+ if (!reader->Contains(full_name("input_impl_empty"))) {
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ } else {
+ input_impl_.reset();
+ }
+ return Status::OK();
+ }
+
private:
mutex mu_;
int64 i_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/kernels/zip_dataset_op.cc b/tensorflow/core/kernels/zip_dataset_op.cc
index a80b9edbe4..f466c8b268 100644
--- a/tensorflow/core/kernels/zip_dataset_op.cc
+++ b/tensorflow/core/kernels/zip_dataset_op.cc
@@ -35,14 +35,15 @@ class ZipDatasetOp : public DatasetOpKernel {
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
inputs.push_back(input);
}
- *output = new Dataset(inputs);
+ *output = new Dataset(ctx, inputs);
}
private:
- class Dataset : public DatasetBase {
+ class Dataset : public GraphDatasetBase {
public:
- explicit Dataset(const std::vector<DatasetBase*>& inputs)
- : inputs_(inputs) {
+ explicit Dataset(OpKernelContext* ctx,
+ const std::vector<DatasetBase*>& inputs)
+ : GraphDatasetBase(ctx), inputs_(inputs) {
for (const auto& input : inputs_) {
input->Ref();
for (DataType dt : input->output_dtypes()) {
@@ -76,6 +77,21 @@ class ZipDatasetOp : public DatasetOpKernel {
string DebugString() override { return "ZipDatasetOp::Dataset"; }
+ protected:
+ Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ std::vector<NodeBuilder::NodeOut> input_graph_nodes;
+ input_graph_nodes.reserve(inputs_.size());
+ for (const auto& input : inputs_) {
+ Node* input_node;
+ TF_RETURN_IF_ERROR(b->AddParentDataset(input, &input_node));
+ input_graph_nodes.emplace_back(input_node);
+ }
+ TF_RETURN_IF_ERROR(
+ b->AddDatasetWithInputAsList(this, input_graph_nodes, output));
+ return Status::OK();
+ }
+
private:
class Iterator : public DatasetIterator<Dataset> {
public:
@@ -93,6 +109,10 @@ class ZipDatasetOp : public DatasetOpKernel {
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
+ if (input_impls_.empty()) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
out_tensors->clear();
out_tensors->reserve(dataset()->output_dtypes().size());
for (const auto& input_impl : input_impls_) {
@@ -100,12 +120,43 @@ class ZipDatasetOp : public DatasetOpKernel {
TF_RETURN_IF_ERROR(
input_impl->GetNext(ctx, &input_tensors, end_of_sequence));
if (*end_of_sequence) {
- return Status::OK();
+ break;
}
out_tensors->insert(out_tensors->end(), input_tensors.begin(),
input_tensors.end());
}
- *end_of_sequence = false;
+ if (*end_of_sequence) {
+ out_tensors->clear();
+ input_impls_.clear();
+ } else {
+ *end_of_sequence = false;
+ }
+ return Status::OK();
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ if (input_impls_.empty()) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("input_impls_empty"), ""));
+ } else {
+ for (auto& input_impl : input_impls_)
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl));
+ }
+ return Status::OK();
+ }
+
+ Status RestoreInternal(OpKernelContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ if (reader->Contains(full_name("input_impls_empty"))) {
+ input_impls_.clear();
+ } else {
+ DCHECK_EQ(input_impls_.size(), dataset()->inputs_.size());
+ for (auto& input_impl : input_impls_)
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl));
+ }
return Status::OK();
}
diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h
index 7d258b36c5..94f4a377f1 100644
--- a/tensorflow/core/lib/core/stringpiece.h
+++ b/tensorflow/core/lib/core/stringpiece.h
@@ -51,11 +51,6 @@ class StringPiece {
// Create a slice that refers to s[0,strlen(s)-1]
StringPiece(const char* s) : data_(s), size_(strlen(s)) {}
- void set(const void* data, size_t len) {
- data_ = reinterpret_cast<const char*>(data);
- size_ = len;
- }
-
// Return a pointer to the beginning of the referenced data
const char* data() const { return data_; }
@@ -79,12 +74,6 @@ class StringPiece {
return data_[n];
}
- // Change this slice to refer to an empty array
- void clear() {
- data_ = "";
- size_ = 0;
- }
-
// Drop the first "n" bytes from this slice.
void remove_prefix(size_t n) {
assert(n <= size());
diff --git a/tensorflow/core/lib/io/block.cc b/tensorflow/core/lib/io/block.cc
index 1fa26d9147..4c30486cc4 100644
--- a/tensorflow/core/lib/io/block.cc
+++ b/tensorflow/core/lib/io/block.cc
@@ -199,7 +199,7 @@ class Block::Iter : public Iterator {
restart_index_ = num_restarts_;
status_ = errors::DataLoss("bad entry in block");
key_.clear();
- value_.clear();
+ value_ = StringPiece();
}
bool ParseNextKey() {
diff --git a/tensorflow/core/lib/strings/str_util.cc b/tensorflow/core/lib/strings/str_util.cc
index 8509c9a041..240e1454e5 100644
--- a/tensorflow/core/lib/strings/str_util.cc
+++ b/tensorflow/core/lib/strings/str_util.cc
@@ -407,11 +407,11 @@ bool ConsumeNonWhitespace(StringPiece* s, StringPiece* val) {
}
const size_t n = p - s->data();
if (n > 0) {
- val->set(s->data(), n);
+ *val = StringPiece(s->data(), n);
s->remove_prefix(n);
return true;
} else {
- val->clear();
+ *val = StringPiece();
return false;
}
}
diff --git a/tensorflow/core/lib/strings/strcat.cc b/tensorflow/core/lib/strings/strcat.cc
index 46a45a6678..5b1cff486d 100644
--- a/tensorflow/core/lib/strings/strcat.cc
+++ b/tensorflow/core/lib/strings/strcat.cc
@@ -45,7 +45,7 @@ AlphaNum::AlphaNum(Hex hex) {
value >>= 4;
mask >>= 4;
} while (mask != 0);
- piece_.set(writer, end - writer);
+ piece_ = StringPiece(writer, end - writer);
}
// ----------------------------------------------------------------------
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 8b8251f84b..60f67543f1 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -8271,6 +8271,29 @@ op {
}
}
op {
+ name: "DatasetToSingleElement"
+ input_arg {
+ name: "dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "DebugGradientIdentity"
input_arg {
name: "input"
@@ -9249,6 +9272,69 @@ op {
}
}
op {
+ name: "DenseToSparseBatchDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "batch_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "row_shape"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "DenseToSparseBatchDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "batch_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "row_shape"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "DenseToSparseSetOperation"
input_arg {
name: "set1"
@@ -9742,6 +9828,18 @@ op {
}
}
op {
+ name: "DeserializeIterator"
+ input_arg {
+ name: "resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "serialized"
+ type: DT_VARIANT
+ }
+ is_stateful: true
+}
+op {
name: "DeserializeManySparse"
input_arg {
name: "serialized_sparse"
@@ -13495,6 +13593,131 @@ op {
}
}
op {
+ name: "GroupByWindowDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "key_func_other_arguments"
+ type_list_attr: "Tkey_func_other_arguments"
+ }
+ input_arg {
+ name: "reduce_func_other_arguments"
+ type_list_attr: "Treduce_func_other_arguments"
+ }
+ input_arg {
+ name: "window_size_func_other_arguments"
+ type_list_attr: "Twindow_size_func_other_arguments"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "key_func"
+ type: "func"
+ }
+ attr {
+ name: "reduce_func"
+ type: "func"
+ }
+ attr {
+ name: "window_size_func"
+ type: "func"
+ }
+ attr {
+ name: "Tkey_func_other_arguments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Treduce_func_other_arguments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Twindow_size_func_other_arguments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "GroupByWindowDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "key_func_other_arguments"
+ type_list_attr: "Tkey_func_other_arguments"
+ }
+ input_arg {
+ name: "reduce_func_other_arguments"
+ type_list_attr: "Treduce_func_other_arguments"
+ }
+ input_arg {
+ name: "window_size_func_other_arguments"
+ type_list_attr: "Twindow_size_func_other_arguments"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "key_func"
+ type: "func"
+ }
+ attr {
+ name: "reduce_func"
+ type: "func"
+ }
+ attr {
+ name: "window_size_func"
+ type: "func"
+ }
+ attr {
+ name: "Tkey_func_other_arguments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Treduce_func_other_arguments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Twindow_size_func_other_arguments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "HSVToRGB"
input_arg {
name: "images"
@@ -13915,6 +14138,53 @@ op {
}
}
op {
+ name: "IgnoreErrorsDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "IgnoreErrorsDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Imag"
input_arg {
name: "input"
@@ -15819,6 +16089,50 @@ op {
is_stateful: true
}
op {
+ name: "MapAndBatchDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ input_arg {
+ name: "batch_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "num_parallel_batches"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "MapClear"
attr {
name: "capacity"
@@ -20557,6 +20871,54 @@ op {
}
}
op {
+ name: "ParallelInterleaveDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ input_arg {
+ name: "cycle_length"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "block_length"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "sloppy"
+ type: DT_BOOL
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "ParallelMapDataset"
input_arg {
name: "input_dataset"
@@ -21309,6 +21671,52 @@ op {
is_stateful: true
}
op {
+ name: "Print"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "data"
+ type_list_attr: "U"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "U"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "message"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "first_n"
+ type: "int"
+ default_value {
+ i: -1
+ }
+ }
+ attr {
+ name: "summarize"
+ type: "int"
+ default_value {
+ i: 3
+ }
+ }
+ is_stateful: true
+}
+op {
name: "PriorityQueue"
output_arg {
name: "handle"
@@ -30147,6 +30555,52 @@ op {
}
}
op {
+ name: "ScanDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "initial_state"
+ type_list_attr: "Tstate"
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Tstate"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "ScatterAdd"
input_arg {
name: "ref"
@@ -31862,6 +32316,18 @@ op {
}
}
op {
+ name: "SerializeIterator"
+ input_arg {
+ name: "resource_handle"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "serialized"
+ type: DT_VARIANT
+ }
+ is_stateful: true
+}
+op {
name: "SerializeManySparse"
input_arg {
name: "sparse_indices"
@@ -37266,6 +37732,38 @@ op {
}
}
op {
+ name: "SqlDataset"
+ input_arg {
+ name: "driver_name"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "data_source_name"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "query"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
name: "Sqrt"
input_arg {
name: "x"
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 8f5d8308a3..f512213964 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -141,6 +141,16 @@ count: A scalar representing the number of elements from the `input_dataset`
that should be skipped. If count is -1, skips everything.
)doc");
+REGISTER_OP("IgnoreErrorsDataset")
+ .Input("input_dataset: variant")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that contains the elements of `input_dataset` ignoring errors.
+)doc");
+
REGISTER_OP("MapDataset")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")
@@ -174,6 +184,32 @@ num_parallel_calls: The number of concurrent invocations of `f` that process
elements from `input_dataset` in parallel.
)doc");
+REGISTER_OP("MapAndBatchDataset")
+ .Input("input_dataset: variant")
+ .Input("other_arguments: Targuments")
+ .Input("batch_size: int64")
+ .Input("num_parallel_batches: int64")
+ .Output("handle: variant")
+ .Attr("f: func")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that applies `f` to the outputs of `input_dataset` and then
+batches `batch_size` of them.
+
+Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up
+to `batch_size * num_parallel_batches` copies of `f` in parallel.
+
+batch_size: A scalar representing the number of elements to accumulate in a
+ batch. It determines the number of concurrent invocations of `f` that process
+ elements from `input_dataset` in parallel.
+num_parallel_batches: A scalar representing the number of batches to create in
+ parallel. Processing multiple batches in parallel benefits workloads prone to
+ stragglers.
+)doc");
+
REGISTER_OP("PrefetchDataset")
.Input("input_dataset: variant")
.Input("buffer_size: int64")
@@ -188,6 +224,21 @@ buffer_size: The maximum number of elements to buffer in an iterator over
this dataset.
)doc");
+REGISTER_OP("ScanDataset")
+ .Input("input_dataset: variant")
+ .Input("initial_state: Tstate")
+ .Input("other_arguments: Targuments")
+ .Output("handle: variant")
+ .Attr("f: func")
+ .Attr("Tstate: list(type) >= 1")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset successively reduces `f` over the elements of `input_dataset`.
+)doc");
+
REGISTER_OP("FlatMapDataset")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")
@@ -234,6 +285,59 @@ f: A function mapping elements of `input_dataset`, concatenated with
`output_types` and `output_shapes`.
)doc");
+REGISTER_OP("ParallelInterleaveDataset")
+ .Input("input_dataset: variant")
+ .Input("other_arguments: Targuments")
+ .Input("cycle_length: int64")
+ .Input("block_length: int64")
+ .Input("sloppy: bool")
+ .Output("handle: variant")
+ .Attr("f: func")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that applies `f` to the outputs of `input_dataset`.
+
+The resulting dataset is similar to the `InterleaveDataset`, with the exception
+that if retrieving the next value from a dataset would cause the requester to
+block, it will skip that input dataset. This dataset is especially useful
+when loading data from a variable-latency datastores (e.g. HDFS, GCS), as it
+allows the training step to proceed so long as some data is available.
+
+!! WARNING !! This dataset is not deterministic!
+
+f: A function mapping elements of `input_dataset`, concatenated with
+ `other_arguments`, to a Dataset variant that contains elements matching
+ `output_types` and `output_shapes`.
+)doc");
+
+REGISTER_OP("GroupByWindowDataset")
+ .Input("input_dataset: variant")
+ .Input("key_func_other_arguments: Tkey_func_other_arguments")
+ .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
+ .Input(
+ "window_size_func_other_arguments: Twindow_size_func_other_arguments")
+ .Output("handle: variant")
+ .Attr("key_func: func")
+ .Attr("reduce_func: func")
+ .Attr("window_size_func: func")
+ .Attr("Tkey_func_other_arguments: list(type) >= 0")
+ .Attr("Treduce_func_other_arguments: list(type) >= 0")
+ .Attr("Twindow_size_func_other_arguments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that computes a windowed group-by on `input_dataset`.
+
+// TODO(mrry): Support non-int64 keys.
+
+key_func: A function mapping an element of `input_dataset`, concatenated
+ with `key_func_other_arguments` to a scalar value of type DT_INT64.
+)doc");
+
REGISTER_OP("FilterDataset")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")
@@ -304,6 +408,27 @@ padding_values: A list of scalars containing the padding value to use for
each of the outputs.
)doc");
+REGISTER_OP("DenseToSparseBatchDataset")
+ .Input("input_dataset: variant")
+ .Input("batch_size: int64")
+ .Input("row_shape: int64")
+ .Output("handle: variant")
+ // NOTE(mrry): the 0th and 2nd elements will be DT_INT64.
+ .Attr("output_types: list(type) >= 1")
+ // NOTE(mrry): the 1st and 2nd elements will be vectors.
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that yields a SparseTensor for each element of the input.
+
+input_dataset: A handle to an input dataset. Must have a single component.
+batch_size: A scalar representing the number of elements to accumulate in a
+ batch.
+row_shape: A vector representing the dense shape of each row in the produced
+ SparseTensor. The shape may be partially specified, using `-1` to indicate
+ that a particular dimension should use the maximum size of all batch elements.
+)doc");
+
REGISTER_OP("RangeDataset")
.Input("start: int64")
.Input("stop: int64")
@@ -389,6 +514,24 @@ compression_type: A scalar containing either (i) the empty string (no
buffer_size: A scalar containing the number of bytes to buffer.
)doc");
+REGISTER_OP("SqlDataset")
+ .Input("driver_name: string")
+ .Input("data_source_name: string")
+ .Input("query: string")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that executes a SQL query and emits rows of the result set.
+
+driver_name: The database type. Currently, the only supported type is 'sqlite'.
+data_source_name: A connection string to connect to the database.
+query: A SQL query to execute.
+)doc");
+
REGISTER_OP("FixedLengthRecordDataset")
.Input("filenames: string")
.Input("header_bytes: int64")
@@ -519,6 +662,36 @@ REGISTER_OP("IteratorGetNext")
Gets the next output from the given iterator.
)doc");
+REGISTER_OP("DatasetToSingleElement")
+ .Input("dataset: variant")
+ .Output("components: output_types")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ std::vector<PartialTensorShape> output_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ if (output_shapes.size() != c->num_outputs()) {
+ return errors::InvalidArgument(
+ "`output_shapes` must be the same length as `output_types` (",
+ output_shapes.size(), " vs. ", c->num_outputs());
+ }
+ for (size_t i = 0; i < output_shapes.size(); ++i) {
+ shape_inference::ShapeHandle output_shape_handle;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ output_shapes[i], &output_shape_handle));
+ c->set_output(static_cast<int>(i), output_shape_handle);
+ }
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Outputs the single element from the given dataset.
+
+dataset: A handle to a dataset that contains a single element.
+components: The components of the single element of `input`.
+)doc");
+
REGISTER_OP("IteratorToStringHandle")
.Input("resource_handle: resource")
.Output("string_handle: string")
@@ -547,4 +720,28 @@ output_shapes: If specified, defines the shape of each tuple component in an
element produced by the resulting iterator.
)doc");
+REGISTER_OP("SerializeIterator")
+ .Input("resource_handle: resource")
+ .Output("serialized: variant")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Converts the given `resource_handle` representing an iterator to a variant tensor.
+
+resource_handle: A handle to an iterator resource.
+serialized: A variant tensor storing the state of the iterator contained in the
+ resource.
+)doc");
+
+REGISTER_OP("DeserializeIterator")
+ .Input("resource_handle: resource")
+ .Input("serialized: variant")
+ .SetShapeFn(shape_inference::NoOutputs)
+ .Doc(R"doc(
+Converts the given variant tensor to an iterator and stores it in the given resource.
+
+resource_handle: A handle to an iterator resource.
+serialized: A variant tensor storing the state of the iterator contained in the
+ resource.
+)doc");
+
} // namespace tensorflow
diff --git a/tensorflow/core/ops/logging_ops.cc b/tensorflow/core/ops/logging_ops.cc
index 11cb9861a3..e6995821df 100644
--- a/tensorflow/core/ops/logging_ops.cc
+++ b/tensorflow/core/ops/logging_ops.cc
@@ -43,7 +43,7 @@ REGISTER_OP("Print")
.Output("output: T")
.SetIsStateful()
.Attr("T: type")
- .Attr("U: list(type)")
+ .Attr("U: list(type) >= 0")
.Attr("message: string = ''")
.Attr("first_n: int = -1")
.Attr("summarize: int = 3")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 2c73441e7d..1447736676 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -6064,6 +6064,32 @@ op {
description: "By default, this op performs an inclusive cumsum, which means that the first\nelement of the input is identical to the first element of the output:\n\n```python\ntf.cumsum([a, b, c]) # => [a, a + b, a + b + c]\n```\n\nBy setting the `exclusive` kwarg to `True`, an exclusive cumsum is\nperformed instead:\n\n```python\ntf.cumsum([a, b, c], exclusive=True) # => [0, a, a + b]\n```\n\nBy setting the `reverse` kwarg to `True`, the cumsum is performed in the\nopposite direction:\n\n```python\ntf.cumsum([a, b, c], reverse=True) # => [a + b + c, b + c, c]\n```\n\nThis is more efficient than using separate `tf.reverse` ops.\n\nThe `reverse` and `exclusive` kwargs can also be combined:\n\n```python\ntf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0]\n```"
}
op {
+ name: "DatasetToSingleElement"
+ input_arg {
+ name: "dataset"
+ description: "A handle to a dataset that contains a single element."
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "components"
+ description: "The components of the single element of `input`."
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ summary: "Outputs the single element from the given dataset."
+}
+op {
name: "DebugGradientIdentity"
input_arg {
name: "input"
@@ -6695,6 +6721,41 @@ op {
description: "See SetOperationOp::SetOperationFromContext for values of `set_operation`.\n\nOutput `result` is a `SparseTensor` represented by `result_indices`,\n`result_values`, and `result_shape`. For `set1` and `set2` ranked `n`, this\nhas rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth`\ndimension contains the result of `set_operation` applied to the corresponding\n`[0...n-1]` dimension of `set`."
}
op {
+ name: "DenseToSparseBatchDataset"
+ input_arg {
+ name: "input_dataset"
+ description: "A handle to an input dataset. Must have a single component."
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "batch_size"
+ description: "A scalar representing the number of elements to accumulate in a\nbatch."
+ type: DT_INT64
+ }
+ input_arg {
+ name: "row_shape"
+ description: "A vector representing the dense shape of each row in the produced\nSparseTensor. The shape may be partially specified, using `-1` to indicate\nthat a particular dimension should use the maximum size of all batch elements."
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ summary: "Creates a dataset that yields a SparseTensor for each element of the input."
+}
+op {
name: "DenseToSparseSetOperation"
input_arg {
name: "set1"
@@ -7034,6 +7095,21 @@ op {
description: "[min_range, max_range] are scalar floats that specify the range for\nthe \'input\' data. The \'mode\' attribute controls exactly which calculations are\nused to convert the float values to their quantized equivalents.\n\nIn \'MIN_COMBINED\' mode, each value of the tensor will undergo the following:\n\n```\nif T == qint8, in[i] += (range(T) + 1)/ 2.0\nout[i] = min_range + (in[i]* (max_range - min_range) / range(T))\n```\nhere `range(T) = numeric_limits<T>::max() - numeric_limits<T>::min()`\n\n*MIN_COMBINED Mode Example*\n\nIf the input comes from a QuantizedRelu6, the output type is\nquint8 (range of 0-255) but the possible range of QuantizedRelu6 is\n0-6. The min_range and max_range values are therefore 0.0 and 6.0.\nDequantize on quint8 will take each value, cast to float, and multiply\nby 6 / 255.\nNote that if quantizedtype is qint8, the operation will additionally add\neach value by 128 prior to casting.\n\nIf the mode is \'MIN_FIRST\', then this approach is used:\n\n```c++\nnum_discrete_values = 1 << (# of bits in T)\nrange_adjust = num_discrete_values / (num_discrete_values - 1)\nrange = (range_max - range_min) * range_adjust\nrange_scale = range / num_discrete_values\nconst double offset_input = static_cast<double>(input) - lowest_quantized;\nresult = range_min + ((input - numeric_limits<T>::min()) * range_scale)\n```\n\n*SCALED mode Example*\n\n`SCALED` mode matches the quantization approach used in\n`QuantizeAndDequantize{V2|V3}`.\n\nIf the mode is `SCALED`, we do not use the full range of the output type,\nchoosing to elide the lowest possible value for symmetry (e.g., output range is\n-127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to\n0.\n\nWe first find the range of values in our tensor. The\nrange we use is always centered on 0, so we find m such that\n```c++\n m = max(abs(input_min), abs(input_max))\n```\n\nOur input tensor range is then `[-m, m]`.\n\nNext, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`.\nIf T is signed, this is\n```\n num_bits = sizeof(T) * 8\n [min_fixed, max_fixed] =\n [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1]\n```\n\nOtherwise, if T is unsigned, the fixed-point range is\n```\n [min_fixed, max_fixed] = [0, (1 << num_bits) - 1]\n```\n\nFrom this we compute our scaling factor, s:\n```c++\n s = (2 * m) / (max_fixed - min_fixed)\n```\n\nNow we can dequantize the elements of our tensor:\n```c++\nresult = input * s\n```"
}
op {
+ name: "DeserializeIterator"
+ input_arg {
+ name: "resource_handle"
+ description: "A handle to an iterator resource."
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "serialized"
+ description: "A variant tensor storing the state of the iterator contained in the\nresource."
+ type: DT_VARIANT
+ }
+ summary: "Converts the given variant tensor to an iterator and stores it in the given resource."
+ is_stateful: true
+}
+op {
name: "DeserializeManySparse"
input_arg {
name: "serialized_sparse"
@@ -10148,6 +10224,71 @@ op {
description: "*NOTE*: `GreaterEqual` supports broadcasting. More about broadcasting\n[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)"
}
op {
+ name: "GroupByWindowDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "key_func_other_arguments"
+ type_list_attr: "Tkey_func_other_arguments"
+ }
+ input_arg {
+ name: "reduce_func_other_arguments"
+ type_list_attr: "Treduce_func_other_arguments"
+ }
+ input_arg {
+ name: "window_size_func_other_arguments"
+ type_list_attr: "Twindow_size_func_other_arguments"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "key_func"
+ type: "func"
+ description: "A function mapping an element of `input_dataset`, concatenated\nwith `key_func_other_arguments` to a scalar value of type DT_INT64."
+ }
+ attr {
+ name: "reduce_func"
+ type: "func"
+ }
+ attr {
+ name: "window_size_func"
+ type: "func"
+ }
+ attr {
+ name: "Tkey_func_other_arguments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Treduce_func_other_arguments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Twindow_size_func_other_arguments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ summary: "Creates a dataset that computes a windowed group-by on `input_dataset`."
+ description: "// TODO(mrry): Support non-int64 keys."
+}
+op {
name: "HSVToRGB"
input_arg {
name: "images"
@@ -10608,6 +10749,30 @@ op {
description: "The upper regularized incomplete Gamma function is defined as:\n\n\\\\(Q(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x)\\\\)\n\nwhere\n\n\\\\(Gamma(a, x) = int_{x}^{\\infty} t^{a-1} exp(-t) dt\\\\)\n\nis the upper incomplete Gama function.\n\nNote, above `P(a, x)` (`Igamma`) is the lower regularized complete\nGamma function."
}
op {
+ name: "IgnoreErrorsDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ summary: "Creates a dataset that contains the elements of `input_dataset` ignoring errors."
+}
+op {
name: "Imag"
input_arg {
name: "input"
@@ -12379,6 +12544,54 @@ op {
is_stateful: true
}
op {
+ name: "MapAndBatchDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ input_arg {
+ name: "batch_size"
+ description: "A scalar representing the number of elements to accumulate in a\nbatch. It determines the number of concurrent invocations of `f` that process\nelements from `input_dataset` in parallel."
+ type: DT_INT64
+ }
+ input_arg {
+ name: "num_parallel_batches"
+ description: "A scalar representing the number of batches to create in\nparallel. Processing multiple batches in parallel benefits workloads prone to\nstragglers."
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ summary: "Creates a dataset that applies `f` to the outputs of `input_dataset` and then"
+ description: "batches `batch_size` of them.\n\nUnlike a \"MapDataset\", which applies `f` sequentially, this dataset invokes up\nto `batch_size * num_parallel_batches` copies of `f` in parallel."
+}
+op {
name: "MapClear"
attr {
name: "capacity"
@@ -16049,6 +16262,57 @@ op {
description: "Builds a merged tensor such that\n\n```python\n merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...]\n```\n\nFor example, if each `indices[m]` is scalar or vector, we have\n\n```python\n # Scalar indices:\n merged[indices[m], ...] = data[m][...]\n\n # Vector indices:\n merged[indices[m][i], ...] = data[m][i, ...]\n```\n\nEach `data[i].shape` must start with the corresponding `indices[i].shape`,\nand the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we\nmust have `data[i].shape = indices[i].shape + constant`. In terms of this\n`constant`, the output shape is\n\n merged.shape = [max(indices)] + constant\n\nValues may be merged in parallel, so if an index appears in both `indices[m][i]`\nand `indices[n][j]`, the result may be invalid. This differs from the normal\nDynamicStitch operator that defines the behavior in that case.\n\nFor example:\n\n```python\n indices[0] = 6\n indices[1] = [4, 1]\n indices[2] = [[5, 2], [0, 3]]\n data[0] = [61, 62]\n data[1] = [[41, 42], [11, 12]]\n data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]]\n merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42],\n [51, 52], [61, 62]]\n```\n\nThis method can be used to merge partitions created by `dynamic_partition`\nas illustrated on the following example:\n\n```python\n # Apply function (increments x_i) on elements for which a certain condition\n # apply (x_i != -1 in this example).\n x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4])\n condition_mask=tf.not_equal(x,tf.constant(-1.))\n partitioned_data = tf.dynamic_partition(\n x, tf.cast(condition_mask, tf.int32) , 2)\n partitioned_data[1] = partitioned_data[1] + 1.0\n condition_indices = tf.dynamic_partition(\n tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2)\n x = tf.dynamic_stitch(condition_indices, partitioned_data)\n # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain\n # unchanged.\n```\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"https://www.tensorflow.org/images/DynamicStitch.png\" alt>\n</div>"
}
op {
+ name: "ParallelInterleaveDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ input_arg {
+ name: "cycle_length"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "block_length"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "sloppy"
+ type: DT_BOOL
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ description: "A function mapping elements of `input_dataset`, concatenated with\n`other_arguments`, to a Dataset variant that contains elements matching\n`output_types` and `output_shapes`."
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
+ description: "The resulting dataset is similar to the `InterleaveDataset`, with the exception\nthat if retrieving the next value from a dataset would cause the requester to\nblock, it will skip that input dataset. This dataset is especially useful\nwhen loading data from a variable-latency datastores (e.g. HDFS, GCS), as it\nallows the training step to proceed so long as some data is available.\n\n!! WARNING !! This dataset is not deterministic!"
+}
+op {
name: "ParallelMapDataset"
input_arg {
name: "input_dataset"
@@ -16718,7 +16982,6 @@ op {
name: "U"
type: "list(type)"
has_minimum: true
- minimum: 1
}
attr {
name: "message"
@@ -23856,6 +24119,53 @@ op {
description: "The input `tags` and `values` must have the same shape. The generated summary\nhas a summary value for each tag-value pair in `tags` and `values`."
}
op {
+ name: "ScanDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "initial_state"
+ type_list_attr: "Tstate"
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Tstate"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ summary: "Creates a dataset successively reduces `f` over the elements of `input_dataset`."
+}
+op {
name: "ScatterAdd"
input_arg {
name: "ref"
@@ -25050,6 +25360,21 @@ op {
summary: "Computes gradients for the scaled exponential linear (Selu) operation."
}
op {
+ name: "SerializeIterator"
+ input_arg {
+ name: "resource_handle"
+ description: "A handle to an iterator resource."
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "serialized"
+ description: "A variant tensor storing the state of the iterator contained in the\nresource."
+ type: DT_VARIANT
+ }
+ summary: "Converts the given `resource_handle` representing an iterator to a variant tensor."
+ is_stateful: true
+}
+op {
name: "SerializeManySparse"
input_arg {
name: "sparse_indices"
@@ -28960,6 +29285,42 @@ op {
summary: "Splits a tensor into `num_split` tensors along one dimension."
}
op {
+ name: "SqlDataset"
+ input_arg {
+ name: "driver_name"
+ description: "The database type. Currently, the only supported type is \'sqlite\'."
+ type: DT_STRING
+ }
+ input_arg {
+ name: "data_source_name"
+ description: "A connection string to connect to the database."
+ type: DT_STRING
+ }
+ input_arg {
+ name: "query"
+ description: "A SQL query to execute."
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ summary: "Creates a dataset that executes a SQL query and emits rows of the result set."
+ is_stateful: true
+}
+op {
name: "Sqrt"
input_arg {
name: "x"
diff --git a/tensorflow/core/ops/summary_ops.cc b/tensorflow/core/ops/summary_ops.cc
index f778b48797..5efbac7ad7 100644
--- a/tensorflow/core/ops/summary_ops.cc
+++ b/tensorflow/core/ops/summary_ops.cc
@@ -49,6 +49,33 @@ flush_millis: How often, in milliseconds, to flush the pending events and
filename_suffix: Every event file's name is suffixed with this suffix.
)doc");
+REGISTER_OP("CreateSummaryDbWriter")
+ .Input("writer: resource")
+ .Input("db_uri: string")
+ .Input("experiment_name: string")
+ .Input("run_name: string")
+ .Input("user_name: string")
+ .SetShapeFn(shape_inference::NoOutputs)
+ .Doc(R"doc(
+Creates summary database writer accessible by given resource handle.
+
+This can be used to write tensors from the execution graph directly
+to a database. Only SQLite is supported right now. This function
+will create the schema if it doesn't exist. Entries in the Users,
+Experiments, and Runs tables will be created automatically if they
+don't already exist.
+
+writer: Handle to SummaryWriter resource to overwrite.
+db_uri: For example "file:/tmp/foo.sqlite".
+experiment_name: Can't contain ASCII control characters or <>. Case
+ sensitive. If empty, then the Run will not be associated with any
+ Experiment.
+run_name: Can't contain ASCII control characters or <>. Case sensitive.
+ If empty, then each Tag will not be associated with any Run.
+user_name: Must be valid as both a DNS label and Linux username. If
+ empty, then the Experiment will not be associated with any User.
+)doc");
+
REGISTER_OP("FlushSummaryWriter")
.Input("writer: resource")
.SetShapeFn(shape_inference::NoOutputs)
@@ -89,6 +116,20 @@ summary_metadata: Serialized SummaryMetadata protocol buffer containing
plugin-related metadata for this summary.
)doc");
+REGISTER_OP("ImportEvent")
+ .Input("writer: resource")
+ .Input("event: string")
+ .SetShapeFn(shape_inference::NoOutputs)
+ .Doc(R"doc(
+Outputs a `tf.Event` protocol buffer.
+
+When CreateSummaryDbWriter is being used, this op can be useful for
+importing data from event logs.
+
+writer: A handle to a summary writer.
+event: A string containing a binary-encoded tf.Event proto.
+)doc");
+
REGISTER_OP("WriteScalarSummary")
.Input("writer: resource")
.Input("global_step: int64")
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index e82aebad0b..17fe704b79 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -247,7 +247,7 @@ class GcsRandomAccessFile : public RandomAccessFile {
/// The implementation of reads with an LRU block cache. Thread safe.
Status Read(uint64 offset, size_t n, StringPiece* result,
char* scratch) const override {
- result->clear();
+ *result = StringPiece();
std::vector<char> out;
TF_RETURN_IF_ERROR(file_block_cache_->Read(filename_, offset, n, &out));
std::memcpy(scratch, out.data(), std::min(out.size(), n));
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 6225c2c705..5eeb861bdd 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -458,16 +458,25 @@ def tf_additional_lib_deps():
def tf_additional_core_deps():
return select({
+ "//tensorflow:with_gcp_support_windows_override": [],
+ "//tensorflow:with_gcp_support_android_override": [],
+ "//tensorflow:with_gcp_support_ios_override": [],
"//tensorflow:with_gcp_support": [
"//tensorflow/core/platform/cloud:gcs_file_system",
],
"//conditions:default": [],
}) + select({
+ "//tensorflow:with_hdfs_support_windows_override": [],
+ "//tensorflow:with_hdfs_support_android_override": [],
+ "//tensorflow:with_hdfs_support_ios_override": [],
"//tensorflow:with_hdfs_support": [
"//tensorflow/core/platform/hadoop:hadoop_file_system",
],
"//conditions:default": [],
}) + select({
+ "//tensorflow:with_s3_support_windows_override": [],
+ "//tensorflow:with_s3_support_android_override": [],
+ "//tensorflow:with_s3_support_ios_override": [],
"//tensorflow:with_s3_support": [
"//tensorflow/core/platform/s3:s3_file_system",
],
@@ -477,9 +486,9 @@ def tf_additional_core_deps():
# TODO(jart, jhseu): Delete when GCP is default on.
def tf_additional_cloud_op_deps():
return select({
- "//tensorflow:windows": [],
- "//tensorflow:android": [],
- "//tensorflow:ios": [],
+ "//tensorflow:with_gcp_support_windows_override": [],
+ "//tensorflow:with_gcp_support_android_override": [],
+ "//tensorflow:with_gcp_support_ios_override": [],
"//tensorflow:with_gcp_support": [
"//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib",
],
@@ -489,9 +498,9 @@ def tf_additional_cloud_op_deps():
# TODO(jart, jhseu): Delete when GCP is default on.
def tf_additional_cloud_kernel_deps():
return select({
- "//tensorflow:windows": [],
- "//tensorflow:android": [],
- "//tensorflow:ios": [],
+ "//tensorflow:with_gcp_support_windows_override": [],
+ "//tensorflow:with_gcp_support_android_override": [],
+ "//tensorflow:with_gcp_support_ios_override": [],
"//tensorflow:with_gcp_support": [
"//tensorflow/contrib/cloud/kernels:bigquery_reader_ops",
],
diff --git a/tensorflow/core/util/bcast.cc b/tensorflow/core/util/bcast.cc
index 47e6ddb3d8..1eab7e3d02 100644
--- a/tensorflow/core/util/bcast.cc
+++ b/tensorflow/core/util/bcast.cc
@@ -68,9 +68,7 @@ BCast::BCast(const Vec& sx, const Vec& sy, const bool fewer_dims_optimization) {
// Output shape.
State curr = UNKNOWN;
const int64 x_i = x[i]; // i-th dimension of x.
- CHECK_GE(x_i, 0);
const int64 y_i = y[i]; // i-th dimension of y.
- CHECK_GE(y_i, 0);
int64 o_i; // i-th dimension of the output.
int64 bx_i; // i-th broadcast for x.
int64 by_i; // i-th broadcast for y.
diff --git a/tensorflow/core/util/device_name_utils.cc b/tensorflow/core/util/device_name_utils.cc
index 2d797c855a..90c3fed2e8 100644
--- a/tensorflow/core/util/device_name_utils.cc
+++ b/tensorflow/core/util/device_name_utils.cc
@@ -116,7 +116,6 @@ bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
if (fullname == "/") {
return true;
}
- StringPiece tmp;
while (!fullname.empty()) {
bool progress = false;
if (str_util::ConsumePrefix(&fullname, "/job:")) {
diff --git a/tensorflow/core/util/memmapped_file_system.cc b/tensorflow/core/util/memmapped_file_system.cc
index e077e94cf8..a0f43d2d4a 100644
--- a/tensorflow/core/util/memmapped_file_system.cc
+++ b/tensorflow/core/util/memmapped_file_system.cc
@@ -58,12 +58,13 @@ class RandomAccessFileFromMemmapped : public RandomAccessFile {
Status Read(uint64 offset, size_t to_read, StringPiece* result,
char* scratch) const override {
if (offset >= length_) {
- result->set(scratch, 0);
+ *result = StringPiece(scratch, 0);
return Status(error::OUT_OF_RANGE, "Read after file end");
}
const uint64 region_left =
std::min(length_ - offset, static_cast<uint64>(to_read));
- result->set(reinterpret_cast<const uint8*>(data_) + offset, region_left);
+ *result =
+ StringPiece(reinterpret_cast<const char*>(data_) + offset, region_left);
return (region_left == to_read)
? Status::OK()
: Status(error::OUT_OF_RANGE, "Read less bytes than requested");
diff --git a/tensorflow/core/util/semver_test.cc b/tensorflow/core/util/semver_test.cc
index 0647f670c7..fdc34fa58b 100644
--- a/tensorflow/core/util/semver_test.cc
+++ b/tensorflow/core/util/semver_test.cc
@@ -39,7 +39,7 @@ bool ConsumeDotSeparatedIdentifiers(StringPiece* s, const string& prefix,
for (i = 0; i < s->size() && IsDotOrIdentifierChar((*s)[i]); ++i) {
// Intentionally empty
}
- val->set(s->data(), i);
+ *val = StringPiece(s->data(), i);
s->remove_prefix(i);
return i > 0;
}
diff --git a/tensorflow/docs_src/mobile/index.md b/tensorflow/docs_src/mobile/index.md
index a6f1422f6f..06ad47bc62 100644
--- a/tensorflow/docs_src/mobile/index.md
+++ b/tensorflow/docs_src/mobile/index.md
@@ -35,8 +35,8 @@ speech-driven interface, and many of these require on-device processing. Most of
the time a user isn’t giving commands, and so streaming audio continuously to a
remote server would be a waste of bandwidth, since it would mostly be silence or
background noises. To solve this problem it’s common to have a small neural
-network running on-device @{$tutorials/audio_recognition$listening out for a
-particular keyword}. Once that keyword has been spotted, the rest of the
+network running on-device @{$tutorials/audio_recognition$listening out for a particular keyword}.
+Once that keyword has been spotted, the rest of the
conversation can be transmitted over to the server for further processing if
more computing power is needed.
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index 3ca3b51a5e..ccced8792e 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -901,6 +901,95 @@ are all 0. Figure below shows examples of different `edge_padding` and
<img style="width:100%" src="https://www.tensorflow.org/images/ops_pad.png">
</div>
+## Recv
+
+See also
+[`ComputationBuilder::Recv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+
+<b> `Recv(shape, channel_handle)` </b>
+
+| Arguments | Type | Semantics |
+| ---------------- | --------------- | ------------------------------------ |
+| `shape` | `Shape` | shape of the data to receive |
+| `channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair |
+
+Receives data of the given shape from a `Send` instruction in another
+computation that shares the same channel handle. Returns a
+ComputationDataHandle for the received data.
+
+The client API of `Recv` operation represents synchronous communication.
+However, the instruction is internally decomposed into 2 HLO instructions
+(`Recv` and `RecvDone`) to enable asynchronous data transfers. See also
+[`HloInstruction::CreateRecv` and `HloInstruction::CreateRecvDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h).
+
+<b>`Recv(const Shape& shape, int64 channel_id)`</b>
+
+Allocates resources required to receive data from a `Send` instruction with the
+same channel_id. Returns a context for the allocated resources, which is used
+by a following `RecvDone` instruction to wait for the completion of the data
+transfer. The context is a tuple of {receive buffer (shape), request identifier
+(U32)} and it can only be used by a `RecvDone` instruction.
+
+<b> `RecvDone(HloInstruction context)` </b>
+
+Given a context created by a `Recv` instruction, waits for the data transfer to
+complete and returns the received data.
+
+## Send
+
+See also
+[`ComputationBuilder::Send`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+
+<b> `Send(operand, channel_handle)` </b>
+
+| Arguments | Type | Semantics |
+| ---------------- | ----------------------- | -------------------------------- |
+| `operand` | `ComputationDataHandle` | data to send (array of type T) |
+| `channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair |
+
+Sends the given operand data to a `Recv` instruction in another computation
+that shares the same channel handle. Does not return any data.
+
+Similar to the `Recv` operation, the client API of `Send` operation represents
+synchronous communication, and is internally decomposed into 2 HLO instructions
+(`Send` and `SendDone`) to enable asynchronous data transfers. See also
+[`HloInstruction::CreateSend` and `HloInstruction::CreateSendDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h).
+
+<b>`Send(HloInstruction operand, int64 channel_id)`</b>
+
+Initiates an asynchronous transfer of the operand to the resources allocated by
+the `Recv` instruction with the same channel id. Returns a context, which is
+used by a following `SendDone` instruction to wait for the completion of the
+data transfer. The context is a tuple of {operand (shape), request identifier
+(U32)} and it can only be used by a `SendDone` instruction.
+
+<b> `SendDone(HloInstruction context)` </b>
+
+Given a context created by a `Send` instruction, waits for the data transfer to
+complete. The instruction does not return any data.
+
+<b> Scheduling of channel instructions </b>
+
+The execution order of the 4 instructions for each channel (`Recv`, `RecvDone`,
+`Send`, `SendDone`) is as below.
+
+<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
+ <img style="width:70%" src="../../images/send_recv_order.png">
+</div>
+
+* `Recv` happens before `Send`
+* `Send` happens before `RecvDone`
+* `Recv` happens before `RecvDone`
+* `Send` happens before `SendDone`
+
+When the backend compilers generate a linear schedule for each computation that
+communicates via channel instructions, there must not be cycles across the
+computations. For example, below schedules lead to deadlocks.
+
+<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
+ <img style="width:100%" src="../../images/send_recv_schedule.png">
+</div>
+
## Reduce
See also
diff --git a/tensorflow/docs_src/programmers_guide/debugger.md b/tensorflow/docs_src/programmers_guide/debugger.md
index dd5496b08e..25cb72008d 100644
--- a/tensorflow/docs_src/programmers_guide/debugger.md
+++ b/tensorflow/docs_src/programmers_guide/debugger.md
@@ -520,8 +520,12 @@ model.fit(...) # This will break into the TFDBG CLI.
## Debugging tf-slim with TFDBG
-TFDBG currently supports only training with
+TFDBG supports debugging of training and evaluation with
[tf-slim](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim).
+As detailed below, training and evaluation require slightly different debugging
+workflows.
+
+### Debugging training in tf-slim
To debug the training process, provide `LocalCLIDebugWrapperSession` to the
`session_wrapper` argument of `slim.learning.train()`. For example:
@@ -530,13 +534,31 @@ import tensorflow as tf
from tensorflow.python import debug as tf_debug
# ... Code that creates the graph and the train_op ...
-tf.contrib.slim.learning_train(
+tf.contrib.slim.learning.train(
train_op,
logdir,
number_of_steps=10,
session_wrapper=tf_debug.LocalCLIDebugWrapperSession)
```
+### Debugging evaluation in tf-slim
+To debug the evaluation process, provide `LocalCLIDebugHook` to the
+`hooks` argument of `slim.evaluation.evaluate_once()`. For example:
+
+``` python
+import tensorflow as tf
+from tensorflow.python import debug as tf_debug
+
+# ... Code that creates the graph and the eval and final ops ...
+tf.contrib.slim.evaluation.evaluate_once(
+ '',
+ checkpoint_path,
+ logdir,
+ eval_op=my_eval_op,
+ final_op=my_value_op,
+ hooks=[tf_debug.LocalCLIDebugHook()])
+```
+
## Offline Debugging of Remotely-Running Sessions
Often, your model is running on a remote machine or a process that you don't
diff --git a/tensorflow/examples/image_retraining/README.md b/tensorflow/examples/image_retraining/README.md
new file mode 100644
index 0000000000..8a49525c6e
--- /dev/null
+++ b/tensorflow/examples/image_retraining/README.md
@@ -0,0 +1,12 @@
+retrain.py is an example script that shows how one can adapt a pretrained
+network for other classification problems. A detailed overview of this script
+can be found at:
+https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/#0
+
+The script also shows how one can train layers
+with quantized weights and activations instead of taking a pre-trained floating
+point model and then quantizing weights and activations.
+The output graphdef produced by this script is compatible with the TensorFlow
+Lite Optimizing Converter and can be converted to TFLite format.
+
+
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index 3549891461..ebddfb20f4 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -69,11 +69,18 @@ to validate that you have gathered good training data, but if you want to deploy
on resource-limited platforms, you can try the `--architecture` flag with a
Mobilenet model. For example:
+Run floating-point version of mobilenet:
```bash
python tensorflow/examples/image_retraining/retrain.py \
--image_dir ~/flower_photos --architecture mobilenet_1.0_224
```
+Run quantized version of mobilenet:
+```bash
+python tensorflow/examples/image_retraining/retrain.py \
+ --image_dir ~/flower_photos/ --architecture mobilenet_1.0_224_quantized
+```
+
There are 32 different Mobilenet models to choose from, with a variety of file
size and latency options. The first number can be '1.0', '0.75', '0.50', or
'0.25' to control the size, and the second controls the input image size, either
@@ -107,6 +114,7 @@ import numpy as np
from six.moves import urllib
import tensorflow as tf
+from tensorflow.contrib.quantize.python import quant_ops
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import gfile
@@ -271,6 +279,7 @@ def create_model_graph(model_info):
"""
with tf.Graph().as_default() as graph:
model_path = os.path.join(FLAGS.model_dir, model_info['model_file_name'])
+ print('Model path: ', model_path)
with gfile.FastGFile(model_path, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
@@ -337,7 +346,10 @@ def maybe_download_and_extract(data_url):
statinfo = os.stat(filepath)
tf.logging.info('Successfully downloaded', filename, statinfo.st_size,
'bytes.')
- tarfile.open(filepath, 'r:gz').extractall(dest_directory)
+ print('Extracting file from ', filepath)
+ tarfile.open(filepath, 'r:gz').extractall(dest_directory)
+ else:
+ print('Not extracting or downloading files, model already present in disk')
def ensure_dir_exists(dir_name):
@@ -733,7 +745,7 @@ def variable_summaries(var):
def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor,
- bottleneck_tensor_size):
+ bottleneck_tensor_size, quantize_layer):
"""Adds a new softmax and fully-connected layer for training.
We need to retrain the top layer to identify our new classes, so this function
@@ -745,10 +757,12 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor,
Args:
class_count: Integer of how many categories of things we're trying to
- recognize.
+ recognize.
final_tensor_name: Name string for the new final node that produces results.
bottleneck_tensor: The output of the main CNN graph.
bottleneck_tensor_size: How many entries in the bottleneck vector.
+ quantize_layer: Boolean, specifying whether the newly added layer should be
+ quantized.
Returns:
The tensors for the training and cross entropy results, and tensors for the
@@ -771,18 +785,41 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor,
with tf.name_scope('weights'):
initial_value = tf.truncated_normal(
[bottleneck_tensor_size, class_count], stddev=0.001)
-
layer_weights = tf.Variable(initial_value, name='final_weights')
+ if quantize_layer:
+ quantized_layer_weights = quant_ops.MovingAvgQuantize(
+ layer_weights, is_training=True)
+ variable_summaries(quantized_layer_weights)
variable_summaries(layer_weights)
with tf.name_scope('biases'):
layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases')
+ if quantize_layer:
+ quantized_layer_biases = quant_ops.MovingAvgQuantize(
+ layer_biases, is_training=True)
+ variable_summaries(quantized_layer_biases)
+
variable_summaries(layer_biases)
+
with tf.name_scope('Wx_plus_b'):
- logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases
- tf.summary.histogram('pre_activations', logits)
+ if quantize_layer:
+ logits = tf.matmul(bottleneck_input,
+ quantized_layer_weights) + quantized_layer_biases
+ logits = quant_ops.MovingAvgQuantize(
+ logits,
+ init_min=-32.0,
+ init_max=32.0,
+ is_training=True,
+ num_bits=8,
+ narrow_range=False,
+ ema_decay=0.5)
+ tf.summary.histogram('pre_activations', logits)
+ else:
+ logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases
+ tf.summary.histogram('pre_activations', logits)
final_tensor = tf.nn.softmax(logits, name=final_tensor_name)
+
tf.summary.histogram('activations', final_tensor)
with tf.name_scope('cross_entropy'):
@@ -790,6 +827,7 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor,
labels=ground_truth_input, logits=logits)
with tf.name_scope('total'):
cross_entropy_mean = tf.reduce_mean(cross_entropy)
+
tf.summary.scalar('cross_entropy', cross_entropy_mean)
with tf.name_scope('train'):
@@ -825,6 +863,7 @@ def add_evaluation_step(result_tensor, ground_truth_tensor):
def save_graph_to_file(sess, graph, graph_file_name):
output_graph_def = graph_util.convert_variables_to_constants(
sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
+
with gfile.FastGFile(graph_file_name, 'wb') as f:
f.write(output_graph_def.SerializeToString())
return
@@ -858,6 +897,7 @@ def create_model_info(architecture):
ValueError: If architecture name is unknown.
"""
architecture = architecture.lower()
+ is_quantized = False
if architecture == 'inception_v3':
# pylint: disable=line-too-long
data_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
@@ -902,19 +942,28 @@ def create_model_info(architecture):
architecture)
return None
is_quantized = True
- data_url = 'http://download.tensorflow.org/models/mobilenet_v1_'
- data_url += version_string + '_' + size_string + '_frozen.tgz'
- bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
+
+ if is_quantized:
+ data_url = 'http://download.tensorflow.org/models/mobilenet_v1_'
+ data_url += version_string + '_' + size_string + '_quantized_frozen.tgz'
+ bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
+ resized_input_tensor_name = 'Placeholder:0'
+ model_dir_name = ('mobilenet_v1_' + version_string + '_' + size_string +
+ '_quantized_frozen')
+ model_base_name = 'quantized_frozen_graph.pb'
+
+ else:
+ data_url = 'http://download.tensorflow.org/models/mobilenet_v1_'
+ data_url += version_string + '_' + size_string + '_frozen.tgz'
+ bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
+ resized_input_tensor_name = 'input:0'
+ model_dir_name = 'mobilenet_v1_' + version_string + '_' + size_string
+ model_base_name = 'frozen_graph.pb'
+
bottleneck_tensor_size = 1001
input_width = int(size_string)
input_height = int(size_string)
input_depth = 3
- resized_input_tensor_name = 'input:0'
- if is_quantized:
- model_base_name = 'quantized_graph.pb'
- else:
- model_base_name = 'frozen_graph.pb'
- model_dir_name = 'mobilenet_v1_' + version_string + '_' + size_string
model_file_name = os.path.join(model_dir_name, model_base_name)
input_mean = 127.5
input_std = 127.5
@@ -933,6 +982,7 @@ def create_model_info(architecture):
'model_file_name': model_file_name,
'input_mean': input_mean,
'input_std': input_std,
+ 'quantize_layer': is_quantized,
}
@@ -1028,7 +1078,7 @@ def main(_):
(train_step, cross_entropy, bottleneck_input, ground_truth_input,
final_tensor) = add_final_training_ops(
len(image_lists.keys()), FLAGS.final_tensor_name, bottleneck_tensor,
- model_info['bottleneck_tensor_size'])
+ model_info['bottleneck_tensor_size'], model_info['quantize_layer'])
# Create the operations we need to evaluate the accuracy of our new layer.
evaluation_step, prediction = add_evaluation_step(
diff --git a/tensorflow/examples/image_retraining/retrain_test.py b/tensorflow/examples/image_retraining/retrain_test.py
index c342a17dd8..2de4c4ec99 100644
--- a/tensorflow/examples/image_retraining/retrain_test.py
+++ b/tensorflow/examples/image_retraining/retrain_test.py
@@ -70,10 +70,18 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase):
def testAddFinalTrainingOps(self, flags_mock):
with tf.Graph().as_default():
with tf.Session() as sess:
- bottleneck = tf.placeholder(
- tf.float32, [1, 1024],
- name='bottleneck')
- retrain.add_final_training_ops(5, 'final', bottleneck, 1024)
+ bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck')
+ # Test creating final training op with quantization
+ retrain.add_final_training_ops(5, 'final', bottleneck, 1024, False)
+ self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0'))
+
+ @tf.test.mock.patch.object(retrain, 'FLAGS', learning_rate=0.01)
+ def testAddFinalTrainingOpsQuantized(self, flags_mock):
+ with tf.Graph().as_default():
+ with tf.Session() as sess:
+ bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck')
+ # Test creating final training op with quantization
+ retrain.add_final_training_ops(5, 'final', bottleneck, 1024, True)
self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0'))
def testAddEvaluationStep(self):
@@ -99,5 +107,12 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase):
self.assertIsNotNone(model_info)
self.assertEqual(299, model_info['input_width'])
+ def testCreateModelInfoQuantized(self):
+ # Test for mobilenet_quantized
+ model_info = retrain.create_model_info('mobilenet_1.0_224')
+ self.assertIsNotNone(model_info)
+ self.assertEqual(224, model_info['input_width'])
+
+
if __name__ == '__main__':
tf.test.main()
diff --git a/tensorflow/examples/learn/iris.py b/tensorflow/examples/learn/iris.py
index 0a50b3ba87..03e60972aa 100644
--- a/tensorflow/examples/learn/iris.py
+++ b/tensorflow/examples/learn/iris.py
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Example of DNNClassifier for Iris plant dataset."""
+"""Example of DNNClassifier for Iris plant dataset.
+
+This example uses APIs in Tensorflow 1.4 or above.
+"""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/examples/learn/wide_n_deep_tutorial.py b/tensorflow/examples/learn/wide_n_deep_tutorial.py
index e447b3e24e..072353392a 100644
--- a/tensorflow/examples/learn/wide_n_deep_tutorial.py
+++ b/tensorflow/examples/learn/wide_n_deep_tutorial.py
@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Example code for TensorFlow Wide & Deep Tutorial using TF.Learn API."""
+"""Example code for TensorFlow Wide & Deep Tutorial using TF High Level API.
+
+This example uses APIs in Tensorflow 1.4 or above.
+"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 4e5d17f76f..eb79da5384 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -62,6 +62,29 @@ func WriteScalarSummary(scope *Scope, writer tf.Output, global_step tf.Output, t
return scope.AddOperation(opspec)
}
+// Outputs a `tf.Event` protocol buffer.
+//
+// When CreateSummaryDbWriter is being used, this op can be useful for
+// importing data from event logs.
+//
+// Arguments:
+// writer: A handle to a summary writer.
+// event: A string containing a binary-encoded tf.Event proto.
+//
+// Returns the created operation.
+func ImportEvent(scope *Scope, writer tf.Output, event tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "ImportEvent",
+ Input: []tf.Input{
+ writer, event,
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
// Outputs a `Summary` protocol buffer with a tensor.
//
// Arguments:
@@ -3983,41 +4006,6 @@ func TensorArrayWriteV2(scope *Scope, handle tf.Output, index tf.Output, value t
return op.Output(0)
}
-// Identity op for gradient debugging.
-//
-// This op is hidden from public in Python. It is used by TensorFlow Debugger to
-// register gradient tensors for gradient debugging.
-func DebugGradientIdentity(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "DebugGradientIdentity",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Deprecated. Use TensorArrayGradV3
-func TensorArrayGradV2(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"source": source}
- opspec := tf.OpSpec{
- Type: "TensorArrayGradV2",
- Input: []tf.Input{
- handle, flow_in,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Get the current size of the TensorArray.
//
// Arguments:
@@ -4551,31 +4539,6 @@ func QueueCloseV2(scope *Scope, handle tf.Output, optional ...QueueCloseV2Attr)
return scope.AddOperation(opspec)
}
-// Concatenates tensors along one dimension.
-//
-// Arguments:
-// values: List of `N` Tensors to concatenate. Their ranks and types must match,
-// and their sizes must match in all dimensions except `concat_dim`.
-// axis: 0-D. The dimension along which to concatenate. Must be in the
-// range [-rank(values), rank(values)).
-//
-// Returns A `Tensor` with the concatenation of values stacked along the
-// `concat_dim` dimension. This tensor's shape matches that of `values` except
-// in `concat_dim` where it has the sum of the sizes.
-func ConcatV2(scope *Scope, values []tf.Output, axis tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "ConcatV2",
- Input: []tf.Input{
- tf.OutputList(values), axis,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// QueueDequeueUpToV2Attr is an optional argument to QueueDequeueUpToV2.
type QueueDequeueUpToV2Attr func(optionalAttr)
@@ -4992,80 +4955,6 @@ func PriorityQueueV2(scope *Scope, shapes []tf.Shape, optional ...PriorityQueueV
return op.Output(0)
}
-// FIFOQueueV2Attr is an optional argument to FIFOQueueV2.
-type FIFOQueueV2Attr func(optionalAttr)
-
-// FIFOQueueV2Shapes sets the optional shapes attribute to value.
-//
-// value: The shape of each component in a value. The length of this attr must
-// be either 0 or the same as the length of component_types. If the length of
-// this attr is 0, the shapes of queue elements are not constrained, and
-// only one element may be dequeued at a time.
-// If not specified, defaults to <>
-//
-// REQUIRES: len(value) >= 0
-func FIFOQueueV2Shapes(value []tf.Shape) FIFOQueueV2Attr {
- return func(m optionalAttr) {
- m["shapes"] = value
- }
-}
-
-// FIFOQueueV2Capacity sets the optional capacity attribute to value.
-//
-// value: The upper bound on the number of elements in this queue.
-// Negative numbers mean no limit.
-// If not specified, defaults to -1
-func FIFOQueueV2Capacity(value int64) FIFOQueueV2Attr {
- return func(m optionalAttr) {
- m["capacity"] = value
- }
-}
-
-// FIFOQueueV2Container sets the optional container attribute to value.
-//
-// value: If non-empty, this queue is placed in the given container.
-// Otherwise, a default container is used.
-// If not specified, defaults to ""
-func FIFOQueueV2Container(value string) FIFOQueueV2Attr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
-
-// FIFOQueueV2SharedName sets the optional shared_name attribute to value.
-//
-// value: If non-empty, this queue will be shared under the given name
-// across multiple sessions.
-// If not specified, defaults to ""
-func FIFOQueueV2SharedName(value string) FIFOQueueV2Attr {
- return func(m optionalAttr) {
- m["shared_name"] = value
- }
-}
-
-// A queue that produces elements in first-in first-out order.
-//
-// Arguments:
-// component_types: The type of each component in a value.
-//
-// Returns The handle to the queue.
-func FIFOQueueV2(scope *Scope, component_types []tf.DataType, optional ...FIFOQueueV2Attr) (handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"component_types": component_types}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "FIFOQueueV2",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// StridedSliceAttr is an optional argument to StridedSlice.
type StridedSliceAttr func(optionalAttr)
@@ -5445,6 +5334,101 @@ func DynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged
return op.Output(0)
}
+// FIFOQueueV2Attr is an optional argument to FIFOQueueV2.
+type FIFOQueueV2Attr func(optionalAttr)
+
+// FIFOQueueV2Shapes sets the optional shapes attribute to value.
+//
+// value: The shape of each component in a value. The length of this attr must
+// be either 0 or the same as the length of component_types. If the length of
+// this attr is 0, the shapes of queue elements are not constrained, and
+// only one element may be dequeued at a time.
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func FIFOQueueV2Shapes(value []tf.Shape) FIFOQueueV2Attr {
+ return func(m optionalAttr) {
+ m["shapes"] = value
+ }
+}
+
+// FIFOQueueV2Capacity sets the optional capacity attribute to value.
+//
+// value: The upper bound on the number of elements in this queue.
+// Negative numbers mean no limit.
+// If not specified, defaults to -1
+func FIFOQueueV2Capacity(value int64) FIFOQueueV2Attr {
+ return func(m optionalAttr) {
+ m["capacity"] = value
+ }
+}
+
+// FIFOQueueV2Container sets the optional container attribute to value.
+//
+// value: If non-empty, this queue is placed in the given container.
+// Otherwise, a default container is used.
+// If not specified, defaults to ""
+func FIFOQueueV2Container(value string) FIFOQueueV2Attr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// FIFOQueueV2SharedName sets the optional shared_name attribute to value.
+//
+// value: If non-empty, this queue will be shared under the given name
+// across multiple sessions.
+// If not specified, defaults to ""
+func FIFOQueueV2SharedName(value string) FIFOQueueV2Attr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// A queue that produces elements in first-in first-out order.
+//
+// Arguments:
+// component_types: The type of each component in a value.
+//
+// Returns The handle to the queue.
+func FIFOQueueV2(scope *Scope, component_types []tf.DataType, optional ...FIFOQueueV2Attr) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"component_types": component_types}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "FIFOQueueV2",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Converts the given `resource_handle` representing an iterator to a variant tensor.
+//
+// Arguments:
+// resource_handle: A handle to an iterator resource.
+//
+// Returns A variant tensor storing the state of the iterator contained in the
+// resource.
+func SerializeIterator(scope *Scope, resource_handle tf.Output) (serialized tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SerializeIterator",
+ Input: []tf.Input{
+ resource_handle,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Return a tensor with the same shape and contents as the input tensor or value.
func Identity(scope *Scope, input tf.Output) (output tf.Output) {
if scope.Err() != nil {
@@ -5576,6 +5560,39 @@ func IteratorToStringHandle(scope *Scope, resource_handle tf.Output) (string_han
return op.Output(0)
}
+// Outputs the single element from the given dataset.
+//
+// Arguments:
+// dataset: A handle to a dataset that contains a single element.
+//
+//
+//
+// Returns The components of the single element of `input`.
+func DatasetToSingleElement(scope *Scope, dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+ opspec := tf.OpSpec{
+ Type: "DatasetToSingleElement",
+ Input: []tf.Input{
+ dataset,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if components, idx, err = makeOutputList(op, idx, "components"); err != nil {
+ scope.UpdateErr("DatasetToSingleElement", err)
+ return
+ }
+ return components
+}
+
// Gets the next output from the given iterator.
func IteratorGetNext(scope *Scope, iterator tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) {
if scope.Err() != nil {
@@ -5696,6 +5713,30 @@ func FixedLengthRecordDataset(scope *Scope, filenames tf.Output, header_bytes tf
return op.Output(0)
}
+// Creates a dataset that executes a SQL query and emits rows of the result set.
+//
+// Arguments:
+// driver_name: The database type. Currently, the only supported type is 'sqlite'.
+// data_source_name: A connection string to connect to the database.
+// query: A SQL query to execute.
+//
+//
+func SqlDataset(scope *Scope, driver_name tf.Output, data_source_name tf.Output, query tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+ opspec := tf.OpSpec{
+ Type: "SqlDataset",
+ Input: []tf.Input{
+ driver_name, data_source_name, query,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// PlaceholderAttr is an optional argument to Placeholder.
type PlaceholderAttr func(optionalAttr)
@@ -5766,6 +5807,68 @@ func CacheDataset(scope *Scope, input_dataset tf.Output, filename tf.Output, out
return op.Output(0)
}
+// Identity op for gradient debugging.
+//
+// This op is hidden from public in Python. It is used by TensorFlow Debugger to
+// register gradient tensors for gradient debugging.
+func DebugGradientIdentity(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "DebugGradientIdentity",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Deprecated. Use TensorArrayGradV3
+func TensorArrayGradV2(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"source": source}
+ opspec := tf.OpSpec{
+ Type: "TensorArrayGradV2",
+ Input: []tf.Input{
+ handle, flow_in,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Creates a dataset that yields a SparseTensor for each element of the input.
+//
+// Arguments:
+// input_dataset: A handle to an input dataset. Must have a single component.
+// batch_size: A scalar representing the number of elements to accumulate in a
+// batch.
+// row_shape: A vector representing the dense shape of each row in the produced
+// SparseTensor. The shape may be partially specified, using `-1` to indicate
+// that a particular dimension should use the maximum size of all batch elements.
+//
+//
+func DenseToSparseBatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, row_shape tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+ opspec := tf.OpSpec{
+ Type: "DenseToSparseBatchDataset",
+ Input: []tf.Input{
+ input_dataset, batch_size, row_shape,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Creates a dataset that batches and pads `batch_size` elements from the input.
//
// Arguments:
@@ -5826,6 +5929,69 @@ func TensorArrayConcatV2(scope *Scope, handle tf.Output, flow_in tf.Output, dtyp
return op.Output(0), op.Output(1)
}
+// Converts the given variant tensor to an iterator and stores it in the given resource.
+//
+// Arguments:
+// resource_handle: A handle to an iterator resource.
+// serialized: A variant tensor storing the state of the iterator contained in the
+// resource.
+//
+// Returns the created operation.
+func DeserializeIterator(scope *Scope, resource_handle tf.Output, serialized tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "DeserializeIterator",
+ Input: []tf.Input{
+ resource_handle, serialized,
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
+// Concatenates tensors along one dimension.
+//
+// Arguments:
+// values: List of `N` Tensors to concatenate. Their ranks and types must match,
+// and their sizes must match in all dimensions except `concat_dim`.
+// axis: 0-D. The dimension along which to concatenate. Must be in the
+// range [-rank(values), rank(values)).
+//
+// Returns A `Tensor` with the concatenation of values stacked along the
+// `concat_dim` dimension. This tensor's shape matches that of `values` except
+// in `concat_dim` where it has the sum of the sizes.
+func ConcatV2(scope *Scope, values []tf.Output, axis tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "ConcatV2",
+ Input: []tf.Input{
+ tf.OutputList(values), axis,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Creates a dataset that contains the elements of `input_dataset` ignoring errors.
+func IgnoreErrorsDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+ opspec := tf.OpSpec{
+ Type: "IgnoreErrorsDataset",
+ Input: []tf.Input{
+ input_dataset,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Creates a dataset that concatenates `input_dataset` with `another_dataset`.
func ConcatenateDataset(scope *Scope, input_dataset tf.Output, another_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
if scope.Err() != nil {
@@ -22311,6 +22477,39 @@ func QuantizedBiasAdd(scope *Scope, input tf.Output, bias tf.Output, min_input t
return op.Output(0), op.Output(1), op.Output(2)
}
+// Creates summary database writer accessible by given resource handle.
+//
+// This can be used to write tensors from the execution graph directly
+// to a database. Only SQLite is supported right now. This function
+// will create the schema if it doesn't exist. Entries in the Users,
+// Experiments, and Runs tables will be created automatically if they
+// don't already exist.
+//
+// Arguments:
+// writer: Handle to SummaryWriter resource to overwrite.
+// db_uri: For example "file:/tmp/foo.sqlite".
+// experiment_name: Can't contain ASCII control characters or <>. Case
+// sensitive. If empty, then the Run will not be associated with any
+// Experiment.
+// run_name: Can't contain ASCII control characters or <>. Case sensitive.
+// If empty, then each Tag will not be associated with any Run.
+// user_name: Must be valid as both a DNS label and Linux username. If
+// empty, then the Experiment will not be associated with any User.
+//
+// Returns the created operation.
+func CreateSummaryDbWriter(scope *Scope, writer tf.Output, db_uri tf.Output, experiment_name tf.Output, run_name tf.Output, user_name tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "CreateSummaryDbWriter",
+ Input: []tf.Input{
+ writer, db_uri, experiment_name, run_name, user_name,
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
// HistogramFixedWidthAttr is an optional argument to HistogramFixedWidth.
type HistogramFixedWidthAttr func(optionalAttr)
diff --git a/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java b/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java
index 2b431eebf5..499757e8cf 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java
@@ -43,7 +43,6 @@ final class NativeLibrary {
private static final boolean DEBUG =
System.getProperty("org.tensorflow.NativeLibrary.DEBUG") != null;
private static final String JNI_LIBNAME = "tensorflow_jni";
- private static final String FRAMEWORK_LIBNAME = "tensorflow_framework";
public static void load() {
if (isLoaded() || tryLoadLibrary()) {
@@ -59,12 +58,15 @@ final class NativeLibrary {
}
// Native code is not present, perhaps it has been packaged into the .jar file containing this.
// Extract the JNI library itself
- final String jniResourceName = makeResourceName(JNI_LIBNAME);
+ final String jniLibName = System.mapLibraryName(JNI_LIBNAME);
+ final String jniResourceName = makeResourceName(jniLibName);
log("jniResourceName: " + jniResourceName);
final InputStream jniResource =
NativeLibrary.class.getClassLoader().getResourceAsStream(jniResourceName);
// Extract the JNI's dependency
- final String frameworkResourceName = makeResourceName(FRAMEWORK_LIBNAME);
+ final String frameworkLibName =
+ maybeAdjustForMacOS(System.mapLibraryName("tensorflow_framework"));
+ final String frameworkResourceName = makeResourceName(frameworkLibName);
log("frameworkResourceName: " + frameworkResourceName);
final InputStream frameworkResource =
NativeLibrary.class.getClassLoader().getResourceAsStream(frameworkResourceName);
@@ -88,12 +90,15 @@ final class NativeLibrary {
tempPath.deleteOnExit();
final String tempDirectory = tempPath.toString();
if (frameworkResource != null) {
- extractResource(frameworkResource, FRAMEWORK_LIBNAME, tempDirectory);
+ extractResource(frameworkResource, frameworkLibName, tempDirectory);
} else {
- log(frameworkResourceName + " not found. This is fine assuming " + jniResourceName
- + " is not built to depend on it.");
+ log(
+ frameworkResourceName
+ + " not found. This is fine assuming "
+ + jniResourceName
+ + " is not built to depend on it.");
}
- System.load(extractResource(jniResource, JNI_LIBNAME, tempDirectory));
+ System.load(extractResource(jniResource, jniLibName, tempDirectory));
} catch (IOException e) {
throw new UnsatisfiedLinkError(
String.format(
@@ -121,9 +126,27 @@ final class NativeLibrary {
}
}
+ private static String maybeAdjustForMacOS(String libFilename) {
+ if (!System.getProperty("os.name").contains("OS X")) {
+ return libFilename;
+ }
+ // This is macOS, and the TensorFlow release process might have setup dependencies on
+ // libtensorflow_framework.so instead of libtensorflow_framework.dylib. Adjust for that.
+ final ClassLoader cl = NativeLibrary.class.getClassLoader();
+ if (cl.getResource(makeResourceName(libFilename)) != null) {
+ return libFilename;
+ }
+ // liftensorflow_framework.dylib not found, try libtensorflow_framework.so
+ final String suffix = ".dylib";
+ if (!libFilename.endsWith(suffix)) {
+ return libFilename;
+ }
+ return libFilename.substring(0, libFilename.length() - suffix.length()) + ".so";
+ }
+
private static String extractResource(
InputStream resource, String resourceName, String extractToDirectory) throws IOException {
- final File dst = new File(extractToDirectory, System.mapLibraryName(resourceName));
+ final File dst = new File(extractToDirectory, resourceName);
dst.deleteOnExit();
final String dstPath = dst.toString();
log("extracting native library to: " + dstPath);
@@ -157,9 +180,7 @@ final class NativeLibrary {
}
private static String makeResourceName(String baseName) {
- return "org/tensorflow/native/"
- + String.format("%s-%s/", os(), architecture())
- + System.mapLibraryName(baseName);
+ return "org/tensorflow/native/" + String.format("%s-%s/", os(), architecture()) + baseName;
}
private static long copy(InputStream src, File dstFile) throws IOException {
diff --git a/tensorflow/python/client/session_clusterspec_prop_test.py b/tensorflow/python/client/session_clusterspec_prop_test.py
index b77912b4f7..28a4dd27a7 100644
--- a/tensorflow/python/client/session_clusterspec_prop_test.py
+++ b/tensorflow/python/client/session_clusterspec_prop_test.py
@@ -169,7 +169,7 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
# BaseRemoteRendezvous::SameWorkerRecvDone that means the test doesn't
# actually capture the motivating bug unless run on a GPU machine.
#
- # Example error message (before bugfix -- linebreaks added because lint):
+ # Example error message (before bugfix -- line breaks added because lint):
#
# W0718 17:14:41.521534 190121 device_mgr.cc:107] Unknown device:
# /job:worker/replica:0/task:0/device:CPU:0 all devices:
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index f45bc13602..40731aba7d 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -344,16 +344,6 @@ bool PyTensorListToVector(PyObject* py_tensor_list,
%rename("_TF_SetConfig") TF_SetConfig;
%rename("_TF_NewSessionOptions") TF_NewSessionOptions;
-// Create temporary int64_t to pass to TF_OperationGetAttrInt
-%typemap(in, numinputs=0) int64_t* value (int64_t val) {
- $1 = &val;
-}
-
-// Convert value to Python int
-%typemap(argout) int64_t* value {
- $result = PyInt_FromLong(*$1);
-}
-
%include "tensorflow/c/c_api.h"
%include "tensorflow/c/python_api.h"
diff --git a/tensorflow/python/client/timeline.py b/tensorflow/python/client/timeline.py
index f3ba4244ce..1e96ac5ed4 100644
--- a/tensorflow/python/client/timeline.py
+++ b/tensorflow/python/client/timeline.py
@@ -275,7 +275,7 @@ class _TensorTracker(object):
name: The name of the Tensor as a string.
object_id: Chrome Trace object identifier assigned for this Tensor.
timestamp: The creation timestamp of this event as a long integer.
- pid: Process identifier of the assicaiated device, as an integer.
+ pid: Process identifier of the associated device, as an integer.
allocator: Name of the allocator used to create the Tensor.
num_bytes: Number of bytes allocated (long integer).
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 343f316281..09f4349cf3 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -22,6 +22,7 @@ import collections
import threading
import numpy as np
+import six
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
@@ -105,7 +106,7 @@ class Dataset(object):
def make_one_shot_iterator(self):
"""Creates an `Iterator` for enumerating the elements of this dataset.
- **N.B.** The returned iterator will be initialized automatically.
+ Note: The returned iterator will be initialized automatically.
A "one-shot" iterator does not currently support re-initialization.
Returns:
@@ -124,7 +125,18 @@ class Dataset(object):
def _make_dataset():
return self._as_variant_tensor() # pylint: disable=protected-access
- _make_dataset.add_to_graph(ops.get_default_graph())
+ try:
+ _make_dataset.add_to_graph(ops.get_default_graph())
+ except ValueError as err:
+ if "Cannot capture a stateful node" in str(err):
+ raise ValueError(
+ "Failed to create a one-shot iterator for a dataset. "
+ "`Dataset.make_one_shot_iterator()` does not support datasets that "
+ "capture stateful objects, such as a `Variable` or `LookupTable`. "
+ "In these cases, use `Dataset.make_initializable_iterator()`. "
+ "(Original error: %s)" % err)
+ else:
+ six.reraise(ValueError, err)
return iterator_ops.Iterator(
gen_dataset_ops.one_shot_iterator(
diff --git a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py
index d987ba84b5..acea9433e2 100644
--- a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py
+++ b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py
@@ -111,6 +111,20 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
self.assertEqual(repr(None), dump.run_feed_keys_info)
+ def testDumpingOnASingleRunWorksWithRelativePathForDebugDumpDir(self):
+ sess = dumping_wrapper.DumpingDebugWrapperSession(
+ self.sess, session_root=self.session_root, log_usage=False)
+ sess.run(self.inc_v)
+ dump_dirs = glob.glob(os.path.join(self.session_root, "run_*"))
+ cwd = os.getcwd()
+ try:
+ os.chdir(self.session_root)
+ dump = debug_data.DebugDumpDir(
+ os.path.relpath(dump_dirs[0], self.session_root))
+ self.assertAllClose([10.0], dump.get_tensors("v", 0, "DebugIdentity"))
+ finally:
+ os.chdir(cwd)
+
def testDumpingOnASingleRunWithFeedDictWorks(self):
sess = dumping_wrapper.DumpingDebugWrapperSession(
self.sess, session_root=self.session_root, log_usage=False)
@@ -350,12 +364,14 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
thread_name_filter=r"MainThread$")
self.assertAllClose(1.0, sess.run(self.delta))
+ child_thread_result = []
def child_thread_job():
- sess.run(sess.run(self.eta))
+ child_thread_result.append(sess.run(self.eta))
thread = threading.Thread(name="ChildThread", target=child_thread_job)
thread.start()
thread.join()
+ self.assertAllClose([-1.4], child_thread_result)
dump_dirs = glob.glob(os.path.join(self.session_root, "run_*"))
self.assertEqual(1, len(dump_dirs))
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index bcd1e1d0dc..c36647b21c 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -14,11 +14,16 @@ cc_library(
"pywrap_tensor.cc",
"pywrap_tfe_src.cc",
],
- hdrs = ["pywrap_tfe.h"],
+ hdrs = [
+ "pywrap_tensor.h",
+ "pywrap_tfe.h",
+ ],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/c:c_api",
+ "//tensorflow/c:c_api_internal",
"//tensorflow/c/eager:c_api",
+ "//tensorflow/c/eager:c_api_internal",
"//tensorflow/c/eager:tape",
"//tensorflow/core:lib",
"//tensorflow/python:ndarray_tensor",
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 86b3776b8c..f9d6d8aa5e 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -120,6 +120,7 @@ _tracing = False
# gradient function registration site, to be less error-prone
# TODO(apassos) add ops other than those in nn_grad and math_grad
_ops_which_dont_need_outputs = set([
+ "Identity",
"MatMul",
"Conv2DBackpropInput",
"Conv2DBackpropFilter",
@@ -195,6 +196,7 @@ _ops_which_dont_need_outputs = set([
])
_ops_which_dont_need_inputs = set([
+ "Identity",
"Softmax",
"LogSoftmax",
"BiasAdd",
@@ -727,12 +729,24 @@ def _num_elements(grad):
raise ValueError("`grad` not a Tensor or IndexedSlices.")
+_last_shape_dtype = [None, None]
+_last_zero = [None]
+
+
+def _zeros(shape, dtype):
+ """Wraps array_ops.zeros to cache last zero for a given shape and dtype."""
+ if [shape, dtype] != _last_shape_dtype:
+ _last_shape_dtype[:] = [shape, dtype]
+ _last_zero[0] = array_ops.zeros(shape, dtype)
+ return _last_zero[0]
+
+
_default_vspace = imperative_grad.VSpace(
num_elements_fn=_num_elements,
aggregate_fn=_aggregate_grads,
tensor_id=ops.tensor_id,
- zeros=array_ops.zeros,
- ones_like=lambda x: ops.convert_to_tensor(array_ops.ones_like(x)))
+ zeros=_zeros,
+ ones=array_ops.ones)
class GradientTape(object):
@@ -821,5 +835,5 @@ class GradientTape(object):
for x in sources]
grad = imperative_grad.imperative_grad(
_default_vspace, self._tape, [target], sources)
- self.tape = None
+ self._tape = None
return grad
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index ed54b8e12e..86c9cce3fd 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -24,11 +24,11 @@ from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import custom_gradient
-from tensorflow.python.eager import imperative_grad
from tensorflow.python.eager import tape
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
@@ -41,7 +41,6 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import training
-from tensorflow.python.util import compat
class BackpropTest(test.TestCase):
@@ -103,6 +102,18 @@ class BackpropTest(test.TestCase):
grad_fn = backprop.gradients_function(f)
self.assertAllEqual(2., grad_fn(1., dy=2.)[0])
+ def testErrors(self):
+
+ @custom_gradient.custom_gradient
+ def f(x):
+ def grad(_):
+ raise RuntimeError('x')
+ return x, grad
+
+ # TODO(apassos) raise the right error here
+ with self.assertRaises(errors_impl.InternalError):
+ backprop.gradients_function(f)(constant_op.constant(1.0))
+
def testImplicitGradOverEmbeddingLookup(self):
batch_size = 8
embedding_size = 512
@@ -293,6 +304,17 @@ class BackpropTest(test.TestCase):
grad = g.gradient(y, [x])[0]
self.assertEqual(grad.numpy(), 6.0)
+ def testGradientTapeGradientCalledMultipleTimes(self):
+ with backprop.GradientTape() as g:
+ x = constant_op.constant(3.0)
+ g.watch(x)
+ y = x * x
+ z = y * y
+ g.gradient(z, [x])
+ with self.assertRaisesRegexp(
+ RuntimeError, 'GradientTape.gradient can only be called once'):
+ g.gradient(y, [x])
+
def testGradientTapeVariable(self):
v = resource_variable_ops.ResourceVariable(1.0, name='v')
with backprop.GradientTape() as g:
@@ -483,48 +505,6 @@ class BackpropTest(test.TestCase):
initial_value=1., name='testSameObjectForMultipleArguments.Variable')
self.assertAllEqual([1., 1.], np_g(v, v))
- def testEarlyGradAggregation(self):
- # Needs to be a list so mutations by the callback affect this function.
- add_n = []
- def callback(op_type, unused_1, unused_2, unused_3, unused_4):
- if compat.as_bytes(op_type) == compat.as_bytes('AddN'):
- add_n.append(1)
- context.context().add_post_execution_callback(callback)
-
- v = resource_variable_ops.ResourceVariable(constant_op.constant(2.0),
- name='v')
- def fn():
- outputs = []
- for _ in range(20):
- outputs.append(v * constant_op.constant(2.0))
- return math_ops.add_n(outputs)
-
- # By default the aggregation count is 2.
- _ = backprop.implicit_grad(fn)()[0][1]
- self.assertEqual(len(add_n), 2)
- del add_n[:]
-
- # Reduce the aggregation limit, cause the backprop to do some
- # early aggregation.
- # pylint: disable=protected-access
- old_cnt = imperative_grad._MIN_AGGREGATE_COUNT
- old_bytes = imperative_grad._MIN_AGGREGATE_BYTES
- imperative_grad._MIN_AGGREGATE_COUNT = 10
- imperative_grad._MIN_AGGREGATE_BYTES = 1
- _ = backprop.implicit_grad(fn)()
- self.assertEqual(len(add_n), 6)
- del add_n[:]
-
- # Aggregation is also limited by the memory.
- imperative_grad._MIN_AGGREGATE_BYTES = 10000
- _ = backprop.implicit_grad(fn)()
- self.assertEqual(len(add_n), 2)
-
- imperative_grad._MIN_AGGREGATE_COUNT = old_cnt
- imperative_grad._MIN_AGGREGATE_BYTES = old_bytes
- # pylint: enable=protected-access
- context.context().clear_post_execution_callbacks()
-
def testImplicitGradientsCustomGradientAndCachedVariableValue(self):
@custom_gradient.custom_gradient
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index 26a70a617d..435505edd7 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -37,6 +37,7 @@ from tensorflow.python.eager import function
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
@@ -66,7 +67,8 @@ class MicroBenchmarks(test.Benchmark):
func()
end = time.time()
mean_us = (end - start) * 1e6 / num_iters
- self.report_benchmark(iters=num_iters, wall_time=mean_us)
+ self.report_benchmark(iters=num_iters, wall_time=mean_us,
+ extras={"examples_per_sec": num_iters/(end-start)})
def benchmark_create_np_array(self):
func = lambda: np.array([3.0])
@@ -133,6 +135,10 @@ class MicroBenchmarks(test.Benchmark):
func = lambda: m * m
self._run(func, num_iters)
+ def _benchmark_tf_multiply_op(self, m, num_iters):
+ func = lambda: math_ops.multiply(m, m)
+ self._run(func, num_iters)
+
def benchmark_np_multiply(self):
self._benchmark_np_multiply(self._m_2, 30000)
@@ -148,6 +154,47 @@ class MicroBenchmarks(test.Benchmark):
m = self._m_2.gpu()
self._benchmark_tf_multiply(m, 30000)
+ def benchmark_tf_multiply_op_CPU(self):
+ with context.device(CPU):
+ m = self._m_2.cpu()
+ self._benchmark_tf_multiply_op(m, 30000)
+
+ def benchmark_tf_multiply_op_GPU(self):
+ if not context.num_gpus():
+ return
+ with context.device(GPU):
+ m = self._m_2.gpu()
+ self._benchmark_tf_multiply_op(m, 30000)
+
+ def benchmark_tf_identity(self):
+ m = self._m_2
+ self._run(lambda: gen_array_ops.identity(m), 30000)
+
+ def benchmark_tf_gradient_function_identity(self):
+ m = self._m_2
+ self._run(
+ lambda: backprop.gradients_function(gen_array_ops.identity, [0])(m),
+ 30000)
+
+ def benchmark_tf_gradient_forward_identity(self):
+ with backprop.GradientTape() as tape:
+ m = self._m_2
+ tape.watch(m)
+ self._run(lambda: gen_array_ops.identity(m), 30000)
+
+ def benchmark_tf_gradient_tape_push_pop(self):
+
+ def f():
+ with backprop.GradientTape():
+ pass
+ self._run(f, 30000)
+
+ def benchmark_tf_gradient_function_no_op(self):
+ m = self._m_2
+ self._run(
+ lambda: backprop.gradients_function(lambda x: x, [0])(m),
+ 30000)
+
def _benchmark_np_matmul(self, m, transpose_b, num_iters):
a = m.cpu().numpy()
b = a.T if transpose_b else a
diff --git a/tensorflow/python/eager/execute.py b/tensorflow/python/eager/execute.py
index 983c1ea73e..c6457232e9 100644
--- a/tensorflow/python/eager/execute.py
+++ b/tensorflow/python/eager/execute.py
@@ -47,8 +47,7 @@ def execute(op_name, num_outputs, inputs, attrs, ctx, name=None):
name: Customized name for the operation.
Returns:
- None if there are no outputs, a single Tensor object if there is one output
- and a list of Tensor objects if there are multiple outputs.
+ List of output Tensor objects. The list is empty if there are no outputs
Raises:
An exception on error.
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index b1b1de0c41..c542dd77a6 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -407,9 +407,15 @@ def _get_defun_inputs(args):
def _defun_internal(name, func, args, kwds):
"""Defines and returns graph-mode version of func."""
+ container_prefix = ops.get_default_graph()._container_prefix # pylint: disable=protected-access
with context.graph_mode():
captures = {}
tmp_graph = CapturingGraph(captures)
+ # Inherit the container prefix, since this is used for error checking when
+ # isolating eager execution (the container prefix at creation must match the
+ # container prefix when used, and variables accessed in the defun will be
+ # used in the outside context).
+ tmp_graph._container_prefix = container_prefix # pylint: disable=protected-access
# Copy the graph collections to ensure summaries and other things work. This
# lets the function access (but not mutate) collections of the containing
# graph, such as the global step and the summary writer collections.
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 243efccac4..65776ca177 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -62,13 +62,40 @@ class FunctionTest(test.TestCase):
@function.defun
def step():
def inner():
- tape.watch_variable(v)
return v * v
return backprop.implicit_grad(inner)()[0][0]
self.assertAllEqual(step(), 2.0)
+ def testDefunReadVariable(self):
+ v = resource_variable_ops.ResourceVariable(1.0)
+
+ @function.defun
+ def f():
+ return v.read_value()
+
+ self.assertEqual(1.0, float(f()))
+
+ def testDefunAssignAddVariable(self):
+ v = resource_variable_ops.ResourceVariable(1.0)
+
+ @function.defun
+ def f():
+ v.assign_add(2.0)
+ return v.read_value()
+
+ self.assertEqual(3.0, float(f()))
+
+ def testDefunDifferentiable(self):
+ v = resource_variable_ops.ResourceVariable(1.0)
+
+ @function.defun
+ def f():
+ return v * v
+
+ self.assertAllEqual(backprop.implicit_grad(f)()[0][0], 2.0)
+
def testGraphModeCaptureVariable(self):
with context.graph_mode(), self.test_session() as sess:
diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py
index a7f1061d18..ce51d17cfc 100644
--- a/tensorflow/python/eager/graph_callable.py
+++ b/tensorflow/python/eager/graph_callable.py
@@ -247,7 +247,9 @@ def _get_graph_callable_inputs(shape_and_dtypes):
ret.append(_get_graph_callable_inputs(x))
else:
raise errors.InvalidArgumentError(
- None, None, "shape_and_dtypes not ShapeAndDtype, type: %s " % type(x))
+ None, None, "Expected the argument to @graph_callable to be a "
+ "(possibly nested) list or tuple of ShapeAndDtype objects, "
+ "but got an object of type: %s" % type(x))
return tuple(ret) if isinstance(shape_and_dtypes, tuple) else ret
@@ -267,7 +269,7 @@ def _graph_callable_internal(func, shape_and_dtypes):
Args:
func: The tfe Python function to compile.
- shape_and_dtypes: A list of type ShapeAndDtype.
+ shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype objects.
Raises:
ValueError: If any one of func's outputs is not a Tensor.
@@ -430,9 +432,10 @@ def graph_callable(shape_and_dtypes):
ret = foo(tfe.Tensor(2.0)) # `ret` here now is a Tensor with value 9.0.
```
Args:
- shape_and_dtypes: A list of type ShapeAndDtype that specifies shape and type
- information for each of the callable's arguments. The length of this list
- must be equal to the number of arguments accepted by the wrapped function.
+ shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype objects
+ that specifies shape and type information for each of the callable's
+ arguments. The length of this list must be equal to the number of
+ arguments accepted by the wrapped function.
Returns:
A callable graph object.
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index c87719f84a..837cad974a 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -20,114 +20,13 @@ from __future__ import print_function
import collections
-from tensorflow.python.eager import tape as tape_module
-
-
-# Terminology:
-#
-# - op: a possibly composite operation, which has an entry in the tape
-# - target: dy in dx/dy
-# - source: dx in dx/dy
-# - tensor: one of the many inputs or outputs of an operation
-#
-# Below here we do the gradient algorithm. It works as follows:
-#
-# First we filter the tape to just the subset of operations we want to
-# differentiate. In the process of doing so we count how many times each Tensor
-# is used as an input to an op (so we know when we're done computing gradients
-# for that Tensor). We also count, for each tape entry, how many of its output
-# Tensors need gradients to be computed (Tensors which are not used do not need
-# any gradients to be computed).
-#
-# Finally, we start a backprop stack with a set of tape entries for which we
-# have all gradients available. This set usually is a subset of the set of
-# targets (not all since targets which have outputs in the tape will not have
-# gradients available initially).
-#
-# Then we repeatedly pop an entry from the stack, run its backprop, and update
-# the gradients of its inputs. Once we have computed all gradients for a single
-# input we can mark this input as done, and this can trigger adding an entry to
-# the stack if all outputs of that entry are now done.
-#
-# When the stack is empty we have gradients for all tensors we're interested in.
-def _prepare_backprop(vspace, target, tensor_to_op, op_to_entry, id_sources):
- """Filters the tape to only include relevant entries and counts tensor usages.
-
- Args:
- vspace: information about the space we're differentiating in.
- target: the target to optimize.
- tensor_to_op: Map from tensor id to key in op_to_entry that produced it.
- op_to_entry: Map from op id to a tape.TapeEntry object
- id_sources: the ids of the sources wrt the gradient is being taken.
-
- Returns:
- usage counts (how many entries downstream from a tensor use it)
- op_to_entry_map: entry map (a filtered tape, with only the relevant
- entries),
- missing: map from tensor id to how many downstream gradients still need
- to be computed before this tensor's gradient can be computed.
- """
- tensor_stack = [vspace.tensor_id(x) for x in target]
- tensor_usage_counts = {}
- o_to_e = {} # Copy of just the bits we need from op_to_entry
- while tensor_stack:
- t = tensor_stack.pop()
- op = tensor_to_op.get(t, None)
- # op is None or -1 if the tensor is a source (i.e. was watched directly)
- if op is None or op == -1 or op in o_to_e:
- continue
- op_trace = tape_module.TapeEntry(*op_to_entry[op])
- o_to_e[op] = op_trace
- for it in op_trace.input_ids:
- if it in tensor_usage_counts:
- tensor_usage_counts[it] += 1
- else:
- tensor_usage_counts[it] = 1
- if it not in id_sources and it in tensor_to_op:
- tensor_stack.append(it)
- op_missing_tensor_counts = collections.defaultdict(int)
- for t in tensor_usage_counts:
- if t in tensor_to_op and tensor_to_op[t] is not None:
- op_missing_tensor_counts[tensor_to_op[t]] += 1
- return tensor_usage_counts, o_to_e, op_missing_tensor_counts
-
-
-def _initialize_backprop_stack(op_to_entry, op_missing_tensor):
- """Returns the set of tape entries which are available for backprop."""
- ready_ops = []
- for op in op_to_entry:
- if op not in op_missing_tensor:
- ready_ops.append(op)
- return ready_ops
-
-
-def _initial_gradients(vspace, target, output_gradients, tensor_usage_counts):
- """Computes the initial gradients for each Tensor."""
- # Initialize the backprop stack
- gradients = collections.defaultdict(list)
- for i, t in enumerate(target):
- if vspace.tensor_id(t) in tensor_usage_counts:
- # Can't provide a gradient of something we're trying to differentiate
- assert output_gradients is None or output_gradients[i] is None
- else:
- if output_gradients is None or output_gradients[i] is None:
- out_grad = vspace.ones_like(t)
- else:
- out_grad = output_gradients[i]
- gradients[vspace.tensor_id(t)].append(out_grad)
- return gradients
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.framework import errors
VSpace = collections.namedtuple(
"VSpace",
- ["aggregate_fn", "num_elements_fn", "tensor_id", "zeros", "ones_like"])
-
-
-# If over MIN_AGGREGATE_COUNT gradients are accumulated and the total
-# memory consumption is over MIN_AGGREGATE_BYTES, do an early aggregation
-# so as to release the gradient tensor to save memory.
-_MIN_AGGREGATE_COUNT = 4
-_MIN_AGGREGATE_BYTES = 128 * 1024 * 1024
+ ["aggregate_fn", "num_elements_fn", "tensor_id", "zeros", "ones"])
def imperative_grad(
@@ -161,89 +60,6 @@ def imperative_grad(
or if only non-differentiable functions of the source were used in the
computation of target.
"""
- tensor_to_op, op_to_entry = tape.export()
- # This overwrites the op_to_entry variable, which will release all memory used
- # to keep traces that are irrelevant to the gradient computation we're doing
- # here.
- id_sources = [vspace.tensor_id(t) for t in sources]
- tensor_usage_counts, op_to_entry, op_missing_tensor = _prepare_backprop(
- vspace, target, tensor_to_op, op_to_entry, id_sources)
- ready_ops = _initialize_backprop_stack(op_to_entry, op_missing_tensor)
- gradients = _initial_gradients(vspace, target, output_gradients,
- tensor_usage_counts)
- gradients_size = dict()
- # Now exhaust the backprop stack
- while ready_ops:
- op = ready_ops.pop()
- op_trace = op_to_entry.pop(op)
- out_gradients = [gradients.pop(t, None) for t in op_trace.output_ids]
-
- # Cache the last used zero tensor. We reuse it if the next one
- # we need is of the same shape and dtype. This is very helpful in
- # large splits and should have negligible overhead in other cases.
- last_shape_and_dtype = None
- last_zeros = None
- for i in range(len(out_gradients)):
- if out_gradients[i] is None:
- # TODO(apassos) this should be in the right device
- none_indices = _grad_fn_accepts_none_for_indices.get(
- op_trace.op_type, None)
- if none_indices is None or i not in none_indices:
- shape_and_dtype = op_trace.output_shape_and_dtype[i]
- if shape_and_dtype != last_shape_and_dtype:
- last_shape_and_dtype = shape_and_dtype
- last_zeros = vspace.zeros(*shape_and_dtype)
- out_gradients[i] = last_zeros
- else:
- out_gradients[i] = vspace.aggregate_fn(out_gradients[i])
-
- in_gradients = op_trace.backward_function(*(out_gradients))
- for i, t in enumerate(op_trace.input_ids):
- if in_gradients[i] is not None:
- t_grads = gradients.setdefault(t, [])
- t_grads.append(in_gradients[i])
- if len(t_grads) >= _MIN_AGGREGATE_COUNT:
- if t not in gradients_size:
- gradients_size[t] = vspace.num_elements_fn(t_grads[-1])
- size = gradients_size[t]
-
- if len(t_grads) * size * 4 > _MIN_AGGREGATE_BYTES:
- t_grads[:] = [vspace.aggregate_fn(t_grads)]
- if tensor_usage_counts.get(t, 0) > 0:
- tensor_usage_counts[t] -= 1
- if (t in tensor_to_op
- and tensor_usage_counts[t] == 0
- and t not in id_sources):
- in_op = tensor_to_op[t]
- if in_op is None or in_op == -1:
- continue
- if op_missing_tensor.get(in_op, 0) > 0:
- op_missing_tensor[in_op] -= 1
- if op_missing_tensor.get(in_op, 0) == 0:
- ready_ops.append(in_op)
- result = []
- for i, s in enumerate(sources):
- g = gradients.get(vspace.tensor_id(s), None)
- if g is None:
- result.append(None)
- else:
- result.append(vspace.aggregate_fn(g))
- return result
-
-
-# TODO(agarwal): use an automatic mechanism for handling None arguments to
-# gradient functions.
-# Some gradient functions can accept None arguments for gradients. The following
-# maps the operation name to the indices at which the corresponding gradient
-# function can accept None values.
-# e.g. FusedBatchNorm outputs 5 values and hence receives 5 gradient values
-# during backprop. However the gradient function uses only the first of those
-# values and ignores the rest. The entry, "FusedBatchNorm": [1, 2, 3, 4],
-# indicates that only the gradient corresponding to index 0 is used, and the
-# gradient values at indices 1-4 are ignored (and hence can be None). The
-# backprop algorithm can then leverage this by not constructing zeros to
-# pass for those indices.
-_grad_fn_accepts_none_for_indices = {
- "SoftmaxCrossEntropyWithLogits": [1],
- "FusedBatchNorm": [1, 2, 3, 4]
-}
+ with errors.raise_exception_on_not_ok_status() as status:
+ return pywrap_tensorflow.TFE_Py_TapeGradient(
+ tape._tape, vspace, target, sources, output_gradients, status) # pylint: disable=protected-access
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index ca283862f9..653f3ef84e 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/python/lib/core/py_seq_tensor.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
+#include "tensorflow/python/eager/pywrap_tensor.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/c/c_api.h"
@@ -573,7 +574,7 @@ bool EagerTensor_CheckExact(const PyObject* o) {
return Py_TYPE(o) == EagerTensorType;
}
-TFE_TensorHandle* EagerTensorHandle(const PyObject* o) {
+TFE_TensorHandle* EagerTensor_Handle(const PyObject* o) {
return reinterpret_cast<const EagerTensor*>(o)->handle;
}
@@ -594,6 +595,11 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
return reinterpret_cast<PyObject*>(t);
}
+tensorflow::int64 EagerTensor_id(const PyObject* tensor) {
+ CHECK(EagerTensor_CheckExact(tensor));
+ return reinterpret_cast<const EagerTensor*>(tensor)->id;
+}
+
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
if (!PyType_Check(base_class)) {
PyErr_SetString(
diff --git a/tensorflow/python/eager/pywrap_tensor.h b/tensorflow/python/eager/pywrap_tensor.h
new file mode 100644
index 0000000000..aa1efdd1b8
--- /dev/null
+++ b/tensorflow/python/eager/pywrap_tensor.h
@@ -0,0 +1,25 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_H_
+#define TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_H_
+
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/python/lib/core/numpy.h"
+
+bool EagerTensor_CheckExact(const PyObject* o);
+tensorflow::int64 EagerTensor_id(const PyObject* tensor);
+
+#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_H_
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index 1d03df2933..6705483f3b 100644
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -81,7 +81,7 @@ bool EagerTensor_CheckExact(const PyObject* o);
PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle);
// Extracts the handle inside EagerTensor object `o`. Returns nullptr on error.
-TFE_TensorHandle* EagerTensorHandle(const PyObject* o);
+TFE_TensorHandle* EagerTensor_Handle(const PyObject* o);
// Creates the `EagerTensor` class by subclassing `base_class` and returns the
// newly created type, or nullptr on error.
@@ -103,7 +103,16 @@ void TFE_Py_TapeRecordOperation(PyObject* tape, PyObject* op_type,
PyObject* output_tensors,
PyObject* input_tensor_ids,
PyObject* backward_function);
-PyObject* TFE_Py_TapeExport(PyObject* tape);
+
+// Computes a gradient based on information recorded on the tape.`tape` must
+// have been produced by TFE_Py_NewTape. `vspace` must be a
+// imperative_grad.py:VSpace named tuple. `target` and `sources` must be python
+// lists of Tensor objects. `output_gradients` is either None or a python list
+// of either Tensor or None, and if not None should have the same length as
+// target.
+PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
+ PyObject* target, PyObject* sources,
+ PyObject* output_gradients, TF_Status* status);
// Returns an EagerTensor of dimension [len(`tensor_list`)] containing
// the `slice_dim`'th dimension of each tensor in `tensor_list`. In other words,
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 7456eb10f8..372a6bb4b7 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -16,10 +16,13 @@ limitations under the License.
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/c/c_api.h"
+#include "tensorflow/c/c_api_internal.h"
+#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tape.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/python/eager/pywrap_tensor.h"
using tensorflow::string;
@@ -440,10 +443,12 @@ void TFE_DeleteContextCapsule(PyObject* context) {
TF_DeleteStatus(status);
}
+using GradientTape = tensorflow::eager::GradientTape<PyObject, PyObject>;
+
typedef struct {
PyObject_HEAD
/* Type-specific fields go here. */
- tensorflow::eager::GradientTape* tape;
+ GradientTape* tape;
} TFE_Py_Tape;
static void TFE_Py_Tape_Delete(PyObject* tape) {
@@ -478,7 +483,7 @@ PyObject* TFE_Py_NewTape() {
TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
- tape->tape = new tensorflow::eager::GradientTape();
+ tape->tape = new GradientTape();
return reinterpret_cast<PyObject*>(tape);
}
@@ -515,18 +520,50 @@ static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
}
PyObject* TFE_Py_TapeShouldRecord(PyObject* py_tape, PyObject* tensors) {
+ if (tensors == Py_None) {
+ Py_RETURN_FALSE;
+ }
+ PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
+ if (seq == nullptr) {
+ return nullptr;
+ }
+ int len = PySequence_Fast_GET_SIZE(seq);
+ // TODO(apassos) consider not building a list and changing the API to check
+ // each tensor individually.
+ std::vector<tensorflow::int64> tensor_ids;
+ tensor_ids.reserve(len);
+ for (int i = 0; i < len; ++i) {
+ PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
+ if (EagerTensor_CheckExact(item)) {
+ tensor_ids.push_back(EagerTensor_id(item));
+ } else {
+ PyObject* id_field = PyObject_GetAttrString(item, "_id");
+ if (id_field == nullptr) {
+ return nullptr;
+ }
+ tensor_ids.push_back(MakeInt(id_field));
+ Py_DECREF(id_field);
+ }
+ }
+ Py_DECREF(seq);
TFE_Py_Tape* tape = reinterpret_cast<TFE_Py_Tape*>(py_tape);
- return PyBool_FromLong(tape->tape->ShouldRecord(MakeIntList(tensors)));
+ if (tape->tape->ShouldRecord(tensor_ids)) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
+ }
}
void TFE_Py_TapeWatch(PyObject* tape, tensorflow::int64 tensor_id) {
reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
}
-// TODO(apassos) have a fast path for eager tensors here which gets information
-// from the handle instead of from the python object, and use this only for the
-// case of graph tensors.
static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
+ if (EagerTensor_CheckExact(tensor)) {
+ TFE_TensorHandle* t = EagerTensor_Handle(tensor);
+ tensorflow::int64 id = EagerTensor_id(tensor);
+ return tensorflow::eager::TapeTensor{id, t->t.dtype(), t->t.shape()};
+ }
PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
tensorflow::int64 id = MakeInt(id_field);
Py_DECREF(id_field);
@@ -563,11 +600,33 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
return tensorflow::eager::TapeTensor{id, dtype, shape};
}
+std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
+ PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
+ if (seq == nullptr) {
+ return {};
+ }
+ int len = PySequence_Fast_GET_SIZE(seq);
+ std::vector<tensorflow::int64> list;
+ list.reserve(len);
+ for (int i = 0; i < len; ++i) {
+ PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i);
+ if (EagerTensor_CheckExact(tensor)) {
+ list.push_back(EagerTensor_id(tensor));
+ } else {
+ PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
+ list.push_back(MakeInt(id_field));
+ Py_DECREF(id_field);
+ }
+ }
+ Py_DECREF(seq);
+ return list;
+}
+
void TFE_Py_TapeRecordOperation(PyObject* tape, PyObject* op_type,
PyObject* output_tensors,
- PyObject* input_tensor_ids,
+ PyObject* input_tensors,
PyObject* backward_function) {
- std::vector<tensorflow::int64> input_ids = MakeIntList(input_tensor_ids);
+ std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
std::vector<tensorflow::eager::TapeTensor> output_info;
PyObject* seq = PySequence_Fast(output_tensors,
"expected a sequence of integer tensor ids");
@@ -582,9 +641,26 @@ void TFE_Py_TapeRecordOperation(PyObject* tape, PyObject* op_type,
}
}
Py_DECREF(seq);
+ char* op_type_str = nullptr;
+ if (PyBytes_Check(op_type)) {
+ op_type_str = PyBytes_AsString(op_type);
+ } else if (PyUnicode_Check(op_type)) {
+#if PY_MAJOR_VERSION >= 3
+ op_type_str = PyUnicode_AsUTF8(op_type);
+#else
+ PyObject* py_str = PyUnicode_AsUTF8String(op_type);
+ if (py_str == nullptr) return;
+ op_type_str = PyBytes_AS_STRING(py_str);
+ Py_DECREF(py_str);
+#endif
+ } else {
+ PyErr_SetString(PyExc_RuntimeError, "op_type should be a string.");
+ return;
+ }
+
Py_INCREF(backward_function);
reinterpret_cast<TFE_Py_Tape*>(tape)->tape->RecordOperation(
- PyBytes_AsString(op_type), output_info, input_ids, backward_function,
+ op_type_str, output_info, input_ids, backward_function,
[backward_function]() { Py_DECREF(backward_function); });
}
@@ -592,64 +668,218 @@ void TFE_Py_TapeDeleteTrace(PyObject* tape, tensorflow::int64 tensor_id) {
reinterpret_cast<TFE_Py_Tape*>(tape)->tape->DeleteTrace(tensor_id);
}
-// TODO(apassos) when backprop.py moves to C most of this exporting logic can
-// disappear.
-PyObject* TFE_Py_TapeExport(PyObject* tape) {
- std::pair<tensorflow::eager::TensorTape, tensorflow::eager::OpTape> exported =
- reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Export();
- PyObject* tensor_tape = PyDict_New();
- for (const auto& pair : exported.first) {
- PyObject* tid = PyLong_FromLong(pair.first);
- PyObject* opid = PyLong_FromLong(pair.second);
- PyDict_SetItem(tensor_tape, tid, opid);
- Py_DECREF(tid);
- Py_DECREF(opid);
- }
-
- PyObject* op_tape = PyDict_New();
- for (const auto& pair : exported.second) {
- PyObject* opid = PyLong_FromLong(pair.first);
- const auto& entry = pair.second;
- PyObject* op_type = PyBytes_FromString(entry.op_type.c_str());
- PyObject* output_ids = PyList_New(entry.output_tensor_info.size());
- for (int i = 0; i < entry.output_tensor_info.size(); ++i) {
- PyObject* tid = PyLong_FromLong(entry.output_tensor_info[i].id);
- PyList_SET_ITEM(output_ids, i, tid);
+class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyObject> {
+ public:
+ explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {}
+
+ tensorflow::Status Initialize() {
+ num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
+ if (num_elements_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
+ if (aggregate_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ zeros_ = PyObject_GetAttrString(py_vspace_, "zeros");
+ if (zeros_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ ones_ =
+ PyObject_GetAttrString(reinterpret_cast<PyObject*>(py_vspace_), "ones");
+ if (ones_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ return tensorflow::Status::OK();
+ }
+
+ ~PyVSpace() override {
+ Py_XDECREF(num_elements_);
+ Py_XDECREF(aggregate_fn_);
+ Py_XDECREF(zeros_);
+ Py_XDECREF(ones_);
+ }
+
+ tensorflow::int64 NumElements(PyObject* tensor) const final {
+ PyObject* arglist =
+ Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
+ PyObject* result = PyEval_CallObject(num_elements_, arglist);
+ tensorflow::int64 r = MakeInt(result);
+ Py_DECREF(result);
+ Py_DECREF(arglist);
+ return r;
+ }
+
+ PyObject* AggregateGradients(
+ tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
+ PyObject* list = PyList_New(gradient_tensors.size());
+ for (int i = 0; i < gradient_tensors.size(); ++i) {
+ // Note: stealing a reference to the gradient tensors.
+ CHECK(gradient_tensors[i] != nullptr);
+ CHECK(gradient_tensors[i] != Py_None);
+ PyList_SET_ITEM(list, i,
+ reinterpret_cast<PyObject*>(gradient_tensors[i]));
}
- PyObject* input_ids = PyList_New(entry.input_tensor_id.size());
- for (int i = 0; i < entry.input_tensor_id.size(); ++i) {
- PyObject* tid = PyLong_FromLong(entry.input_tensor_id[i]);
- PyList_SET_ITEM(input_ids, i, tid);
+ PyObject* arglist = Py_BuildValue("(O)", list);
+ CHECK(arglist != nullptr);
+ PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
+ Py_DECREF(arglist);
+ Py_DECREF(list);
+ return result;
+ }
+
+ PyObject* Zeros(tensorflow::TensorShape shape,
+ tensorflow::DataType dtype) const final {
+ PyObject* py_shape = PyTuple_New(shape.dims());
+ for (int i = 0; i < shape.dims(); ++i) {
+ PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
}
- PyObject* backward_function =
- reinterpret_cast<PyObject*>(entry.backward_function);
- PyObject* output_shape_and_dtype =
- PyList_New(entry.output_tensor_info.size());
- for (int i = 0; i < entry.output_tensor_info.size(); ++i) {
- const tensorflow::TensorShape& shape = entry.output_tensor_info[i].shape;
- PyObject* shape_list = PyList_New(shape.dims());
- for (int j = 0; j < shape.dims(); ++j) {
- PyList_SET_ITEM(shape_list, j, PyLong_FromLong(shape.dim_size(j)));
+ PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype));
+ PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
+ PyObject* result = PyEval_CallObject(zeros_, arg_list);
+ Py_DECREF(arg_list);
+ Py_DECREF(py_dtype);
+ Py_DECREF(py_shape);
+ return reinterpret_cast<PyObject*>(result);
+ }
+
+ PyObject* Ones(tensorflow::TensorShape shape,
+ tensorflow::DataType dtype) const final {
+ PyObject* py_shape = PyTuple_New(shape.dims());
+ for (int i = 0; i < shape.dims(); ++i) {
+ PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
+ }
+ PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype));
+ PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
+ PyObject* result = PyEval_CallObject(ones_, arg_list);
+ Py_DECREF(arg_list);
+ Py_DECREF(py_dtype);
+ Py_DECREF(py_shape);
+ return result;
+ }
+
+ tensorflow::Status CallBackwardFunction(
+ PyObject* backward_function,
+ tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
+ std::vector<PyObject*>* result) const final {
+ PyObject* grads = PyTuple_New(output_gradients.size());
+ for (int i = 0; i < output_gradients.size(); ++i) {
+ if (output_gradients[i] == nullptr) {
+ Py_INCREF(Py_None);
+ PyTuple_SET_ITEM(grads, i, Py_None);
+ } else {
+ PyTuple_SET_ITEM(grads, i,
+ reinterpret_cast<PyObject*>(output_gradients[i]));
}
- PyObject* type_enum = PyLong_FromLong(entry.output_tensor_info[i].dtype);
- PyObject* tuple = PyTuple_Pack(2, shape_list, type_enum);
- Py_DECREF(shape_list);
- Py_DECREF(type_enum);
- PyList_SET_ITEM(output_shape_and_dtype, i, tuple);
}
- PyObject* opinfo = PyTuple_Pack(5, op_type, output_ids, input_ids,
- backward_function, output_shape_and_dtype);
- Py_DECREF(op_type);
- Py_DECREF(output_ids);
- Py_DECREF(input_ids);
+ PyObject* py_result = PyEval_CallObject(
+ reinterpret_cast<PyObject*>(backward_function), grads);
+ Py_DECREF(grads);
Py_DECREF(backward_function);
- Py_DECREF(output_shape_and_dtype);
- PyDict_SetItem(op_tape, opid, opinfo);
- Py_DECREF(opid);
- Py_DECREF(opinfo);
- }
- PyObject* retval = PyTuple_Pack(2, tensor_tape, op_tape);
- Py_DECREF(tensor_tape);
- Py_DECREF(op_tape);
- return retval;
+ if (py_result == nullptr) {
+ VLOG(1) << "Gradient function threw exceptions";
+ if (VLOG_IS_ON(1)) {
+ PyErr_Print();
+ }
+ return tensorflow::errors::Internal("gradient function threw exceptions");
+ }
+ result->clear();
+ PyObject* seq =
+ PySequence_Fast(py_result, "expected a sequence of gradients");
+ if (seq == nullptr) {
+ return tensorflow::errors::InvalidArgument(
+ "gradient function did not return a list");
+ }
+ int len = PySequence_Fast_GET_SIZE(seq);
+ VLOG(1) << "Gradient length is " << len;
+ result->reserve(len);
+ for (int i = 0; i < len; ++i) {
+ PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
+ if (item == Py_None) {
+ result->push_back(nullptr);
+ } else {
+ Py_INCREF(item);
+ result->push_back(item);
+ }
+ }
+ Py_DECREF(seq);
+ Py_DECREF(py_result);
+ return tensorflow::Status::OK();
+ }
+
+ void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
+
+ private:
+ PyObject* py_vspace_;
+
+ PyObject* num_elements_;
+ PyObject* aggregate_fn_;
+ PyObject* zeros_;
+ PyObject* ones_;
+};
+
+std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
+ PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
+ if (seq == nullptr) {
+ return {};
+ }
+ int len = PySequence_Fast_GET_SIZE(seq);
+ std::vector<PyObject*> list;
+ list.reserve(len);
+ for (int i = 0; i < len; ++i) {
+ list.push_back(PySequence_Fast_GET_ITEM(seq, i));
+ }
+ Py_DECREF(seq);
+ return list;
+}
+
+
+PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
+ PyObject* target, PyObject* sources,
+ PyObject* output_gradients, TF_Status* status) {
+ PyVSpace c_vspace(vspace);
+ if (!c_vspace.Initialize().ok()) {
+ return nullptr;
+ }
+
+ std::vector<tensorflow::int64> target_vec = MakeTensorIDList(target);
+ if (PyErr_Occurred()) {
+ return nullptr;
+ }
+ std::vector<tensorflow::int64> sources_vec = MakeTensorIDList(sources);
+ if (PyErr_Occurred()) {
+ return nullptr;
+ }
+ std::vector<PyObject*> outgrad_vec;
+ if (output_gradients != Py_None) {
+ outgrad_vec = MakeTensorList(output_gradients);
+ if (PyErr_Occurred()) {
+ return nullptr;
+ }
+ for (PyObject* tensor : outgrad_vec) {
+ // Calling the backward function will eat a reference to the tensors in
+ // outgrad_vec, so we need to increase their reference count.
+ Py_INCREF(tensor);
+ }
+ }
+ TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
+ std::vector<PyObject*> result;
+ status->status = tape_obj->tape->ComputeGradient(
+ c_vspace, target_vec, sources_vec, outgrad_vec, &result);
+ if (!status->status.ok()) {
+ return nullptr;
+ }
+ if (!result.empty()) {
+ PyObject* py_result = PyList_New(result.size());
+ for (int i = 0; i < result.size(); ++i) {
+ if (result[i] == nullptr) {
+ Py_INCREF(Py_None);
+ result[i] = Py_None;
+ }
+ PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i]));
+ }
+ return py_result;
+ }
+ Py_INCREF(Py_None);
+ return Py_None;
}
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index c16aa8c2f7..afbad183b0 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -23,7 +23,6 @@ import contextlib
import threading
from tensorflow.python import pywrap_tensorflow
-from tensorflow.python.util import compat
def tid(tensor):
@@ -72,7 +71,7 @@ class Tape(object):
True if any of the tensors is in the tape.
"""
return pywrap_tensorflow.TFE_Py_TapeShouldRecord(
- self._tape, [x._id for x in tensors]) # pylint: disable=protected-access
+ self._tape, tensors)
def watch(self, tensor):
"""Adds a tensor to the tape."""
@@ -87,9 +86,9 @@ class Tape(object):
"""Records an operation in the tape."""
pywrap_tensorflow.TFE_Py_TapeRecordOperation(
self._tape,
- compat.as_bytes(op_type),
+ op_type,
output_tensors,
- [x._id for x in input_tensors], # pylint: disable=protected-access
+ input_tensors,
backward_function)
def _delete_tensor_id(self, i):
@@ -99,16 +98,6 @@ class Tape(object):
"""Deletes any trace we have for this tensor."""
self._delete_tensor_id(tensor_id)
- def export(self):
- """Exports the internal state of this tape.
-
- Returns:
- tensor_tape: a map from tensor_id(tensor) to <identifier for op>
- responsible for generating that tensor.
- op_tape: a map from <identifier for op> to TapeEntry for that op.
- """
- return pywrap_tensorflow.TFE_Py_TapeExport(self._tape)
-
class _TapeStack(threading.local):
diff --git a/tensorflow/python/eager/tape_test.py b/tensorflow/python/eager/tape_test.py
index c97cb62125..b490bac66d 100644
--- a/tensorflow/python/eager/tape_test.py
+++ b/tensorflow/python/eager/tape_test.py
@@ -22,7 +22,6 @@ from __future__ import print_function
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import custom_gradient
-from tensorflow.python.eager import tape
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -166,25 +165,6 @@ class TapeTest(test.TestCase):
g, = backprop.gradients_function(fn, [0])(t)
self.assertAllEqual(g, 1.0)
- def testTapeGC(self):
- # TODO(apassos) figure out how to test this without using tape internal
- # APIs.
- tape.push_new_tape()
-
- def f():
- x = constant_op.constant(1.0)
- tape.watch(x)
- x = gradient_is_constant(x)
- x = gradient_is_constant(x)
- x = gradient_is_constant(x)
-
- f()
- t = tape.pop_tape()
- tensor_tape, op_tape = t.export()
- self.assertEqual(len(tensor_tape), 1) # The watched tensor will remain on
- # the tape
- self.assertEqual(len(op_tape), 0) # No operations should remain on the tape
-
def testCustomGradientGraphMode(self):
with context.graph_mode(), self.test_session():
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 26f1fd888a..03f386e9cf 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -25,6 +25,7 @@ py_library(
srcs = ["estimator_lib.py"],
srcs_version = "PY2AND3",
deps = [
+ ":baseline",
":dnn",
":dnn_linear_combined",
":estimator",
@@ -187,6 +188,68 @@ py_test(
)
py_library(
+ name = "baseline",
+ srcs = ["canned/baseline.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":estimator",
+ ":head",
+ ":model_fn",
+ ":optimizers",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:nn",
+ "//tensorflow/python:partitioned_variables",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/feature_column",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "baseline_test",
+ size = "medium",
+ srcs = ["canned/baseline_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ "notsan", # b/67510291
+ ],
+ deps = [
+ ":baseline",
+ ":estimator",
+ ":export_export",
+ ":metric_keys",
+ ":numpy_io",
+ ":pandas_io",
+ ":run_config",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:data_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/feature_column",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
name = "dnn",
srcs = ["canned/dnn.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/python/estimator/canned/baseline.py b/tensorflow/python/estimator/canned/baseline.py
new file mode 100644
index 0000000000..96e4ecd29f
--- /dev/null
+++ b/tensorflow/python/estimator/canned/baseline.py
@@ -0,0 +1,349 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Baseline estimators.
+
+Baseline estimators are bias-only estimators that can be used for debugging
+and as simple baselines.
+
+Example:
+
+```
+# Build BaselineClassifier
+classifier = BaselineClassifier(n_classes=3)
+
+# Input builders
+def input_fn_train: # returns x, y (where y represents label's class index).
+ pass
+
+def input_fn_eval: # returns x, y (where y represents label's class index).
+ pass
+
+# Fit model.
+classifier.train(input_fn=input_fn_train)
+
+# Evaluate cross entropy between the test and train labels.
+loss = classifier.evaluate(input_fn=input_fn_eval)["loss"]
+
+# predict outputs the probability distribution of the classes as seen in
+# training.
+predictions = classifier.predict(new_samples)
+```
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator.canned import head as head_lib
+from tensorflow.python.estimator.canned import optimizers
+from tensorflow.python.feature_column import feature_column as feature_column_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import training_util
+
+# The default learning rate of 0.3 is a historical artifact of the initial
+# implementation, but seems a reasonable choice.
+_LEARNING_RATE = 0.3
+
+
+def _get_weight_column_key(weight_column):
+ if weight_column is None:
+ return None
+ if isinstance(weight_column, six.string_types):
+ return weight_column
+ if not isinstance(weight_column, feature_column_lib._NumericColumn): # pylint: disable=protected-access
+ raise TypeError('Weight column must be either a string or _NumericColumn.'
+ ' Given type: {}.'.format(type(weight_column)))
+ return weight_column.key()
+
+
+def _baseline_logit_fn_builder(num_outputs, weight_column=None):
+ """Function builder for a baseline logit_fn.
+
+ Args:
+ num_outputs: Number of outputs for the model.
+ weight_column: A string or a `_NumericColumn` created by
+ `tf.feature_column.numeric_column` defining feature column representing
+ weights. It will be multiplied by the loss of the example.
+ Returns:
+ A logit_fn (see below).
+ """
+
+ def baseline_logit_fn(features):
+ """Baseline model logit_fn.
+
+ The baseline model simply learns a bias, so the output logits are a
+ `Variable` with one weight for each output that learns the bias for the
+ corresponding output.
+
+ Args:
+ features: The first item returned from the `input_fn` passed to `train`,
+ `evaluate`, and `predict`. This should be a single `Tensor` or dict with
+ `Tensor` values.
+ Returns:
+ A `Tensor` representing the logits.
+ """
+ size_checks = []
+ batch_size = None
+
+ weight_column_key = _get_weight_column_key(weight_column)
+
+ # The first dimension is assumed to be a batch size and must be consistent
+ # among all of the features.
+ for key, feature in features.items():
+ # Skip weight_column to ensure we don't add size checks to it.
+ # These would introduce a dependency on the weight at serving time.
+ if key == weight_column_key:
+ continue
+ first_dim = array_ops.shape(feature)[0]
+ if batch_size is None:
+ batch_size = first_dim
+ else:
+ size_checks.append(check_ops.assert_equal(batch_size, first_dim))
+
+ with ops.control_dependencies(size_checks):
+ with variable_scope.variable_scope('baseline'):
+ bias = variable_scope.get_variable('bias', shape=[num_outputs],
+ initializer=init_ops.Zeros)
+ return math_ops.multiply(bias, array_ops.ones([batch_size,
+ num_outputs]))
+
+ return baseline_logit_fn
+
+
+def _baseline_model_fn(features, labels, mode, head, optimizer,
+ weight_column=None, config=None):
+ """Model_fn for baseline models.
+
+ Args:
+ features: `Tensor` or dict of `Tensor` (depends on data passed to `train`).
+ labels: `Tensor` of labels that are compatible with the `Head` instance.
+ mode: Defines whether this is training, evaluation or prediction.
+ See `ModeKeys`.
+ head: A `Head` instance.
+ optimizer: String, `tf.Optimizer` object, or callable that creates the
+ optimizer to use for training. If not specified, will use `FtrlOptimizer`
+ with a default learning rate of 0.3.
+ weight_column: A string or a `_NumericColumn` created by
+ `tf.feature_column.numeric_column` defining feature column representing
+ weights. It will be multiplied by the loss of the example.
+ config: `RunConfig` object to configure the runtime settings.
+
+ Raises:
+ KeyError: If weight column is specified but not present.
+ ValueError: If features is an empty dictionary.
+
+ Returns:
+ An `EstimatorSpec` instance.
+ """
+ del config # Unused.
+
+ logit_fn = _baseline_logit_fn_builder(head.logits_dimension, weight_column)
+ logits = logit_fn(features)
+
+ def train_op_fn(loss):
+ opt = optimizers.get_optimizer_instance(
+ optimizer, learning_rate=_LEARNING_RATE)
+ return opt.minimize(loss, global_step=training_util.get_global_step())
+
+ return head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ logits=logits,
+ labels=labels,
+ train_op_fn=train_op_fn)
+
+
+class BaselineClassifier(estimator.Estimator):
+ """A classifier that can establish a simple baseline.
+
+ This classifier ignores feature values and will learn to predict the average
+ value of each label. For single-label problems, this will predict the
+ probability distribution of the classes as seen in the labels. For multi-label
+ problems, this will predict the fraction of examples that are positive for
+ each class.
+
+ Example:
+
+ ```python
+
+ # Build BaselineClassifier
+ classifier = BaselineClassifier(n_classes=3)
+
+ # Input builders
+ def input_fn_train: # returns x, y (where y represents label's class index).
+ pass
+
+ def input_fn_eval: # returns x, y (where y represents label's class index).
+ pass
+
+ # Fit model.
+ classifier.train(input_fn=input_fn_train)
+
+ # Evaluate cross entropy between the test and train labels.
+ loss = classifier.evaluate(input_fn=input_fn_eval)["loss"]
+
+ # predict outputs the probability distribution of the classes as seen in
+ # training.
+ predictions = classifier.predict(new_samples)
+
+ ```
+
+ Input of `train` and `evaluate` should have following features,
+ otherwise there will be a `KeyError`:
+
+ * if `weight_column` is not `None`, a feature with
+ `key=weight_column` whose value is a `Tensor`.
+ """
+
+ def __init__(self,
+ model_dir=None,
+ n_classes=2,
+ weight_column=None,
+ label_vocabulary=None,
+ optimizer='Ftrl',
+ config=None):
+ """Initializes a BaselineClassifier instance.
+
+ Args:
+ model_dir: Directory to save model parameters, graph and etc. This can
+ also be used to load checkpoints from the directory into a estimator to
+ continue training a previously saved model.
+ n_classes: number of label classes. Default is binary classification.
+ It must be greater than 1. Note: Class labels are integers representing
+ the class index (i.e. values from 0 to n_classes-1). For arbitrary
+ label values (e.g. string labels), convert to class indices first.
+ weight_column: A string or a `_NumericColumn` created by
+ `tf.feature_column.numeric_column` defining feature column representing
+ weights. It will be multiplied by the loss of the example.
+ label_vocabulary: Optional list of strings with size `[n_classes]`
+ defining the label vocabulary. Only supported for `n_classes` > 2.
+ optimizer: String, `tf.Optimizer` object, or callable that creates the
+ optimizer to use for training. If not specified, will use
+ `FtrlOptimizer` with a default learning rate of 0.3.
+ config: `RunConfig` object to configure the runtime settings.
+ Returns:
+ A `BaselineClassifier` estimator.
+
+ Raises:
+ ValueError: If `n_classes` < 2.
+ """
+ if n_classes == 2:
+ head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access
+ weight_column=weight_column,
+ label_vocabulary=label_vocabulary)
+ else:
+ head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access
+ n_classes, weight_column=weight_column,
+ label_vocabulary=label_vocabulary)
+ def _model_fn(features, labels, mode, config):
+ return _baseline_model_fn(
+ features=features,
+ labels=labels,
+ mode=mode,
+ head=head,
+ optimizer=optimizer,
+ weight_column=weight_column,
+ config=config)
+ super(BaselineClassifier, self).__init__(
+ model_fn=_model_fn,
+ model_dir=model_dir,
+ config=config)
+
+
+class BaselineRegressor(estimator.Estimator):
+ """A regressor that can establish a simple baseline.
+
+ This regressor ignores feature values and will learn to predict the average
+ value of each label.
+
+ Example:
+
+ ```python
+
+ # Build BaselineRegressor
+ regressor = BaselineRegressor()
+
+ # Input builders
+ def input_fn_train: # returns x, y (where y is the label).
+ pass
+
+ def input_fn_eval: # returns x, y (where y is the label).
+ pass
+
+ # Fit model.
+ regressor.train(input_fn=input_fn_train)
+
+ # Evaluate squared-loss between the test and train targets.
+ loss = regressor.evaluate(input_fn=input_fn_eval)["loss"]
+
+ # predict outputs the mean value seen during training.
+ predictions = regressor.predict(new_samples)
+ ```
+
+ Input of `train` and `evaluate` should have following features,
+ otherwise there will be a `KeyError`:
+
+ * if `weight_column` is not `None`, a feature with
+ `key=weight_column` whose value is a `Tensor`.
+ """
+
+ def __init__(self,
+ model_dir=None,
+ label_dimension=1,
+ weight_column=None,
+ optimizer='Ftrl',
+ config=None):
+ """Initializes a BaselineRegressor instance.
+
+ Args:
+ model_dir: Directory to save model parameters, graph and etc. This can
+ also be used to load checkpoints from the directory into a estimator to
+ continue training a previously saved model.
+ label_dimension: Number of regression targets per example. This is the
+ size of the last dimension of the labels and logits `Tensor` objects
+ (typically, these have shape `[batch_size, label_dimension]`).
+ weight_column: A string or a `_NumericColumn` created by
+ `tf.feature_column.numeric_column` defining feature column representing
+ weights. It will be multiplied by the loss of the example.
+ optimizer: String, `tf.Optimizer` object, or callable that creates the
+ optimizer to use for training. If not specified, will use
+ `FtrlOptimizer` with a default learning rate of 0.3.
+ config: `RunConfig` object to configure the runtime settings.
+ Returns:
+ A `BaselineRegressor` estimator.
+ """
+
+ head = head_lib._regression_head_with_mean_squared_error_loss( # pylint: disable=protected-access
+ label_dimension=label_dimension,
+ weight_column=weight_column)
+ def _model_fn(features, labels, mode, config):
+ return _baseline_model_fn(
+ features=features,
+ labels=labels,
+ mode=mode,
+ head=head,
+ optimizer=optimizer,
+ config=config)
+ super(BaselineRegressor, self).__init__(
+ model_fn=_model_fn,
+ model_dir=model_dir,
+ config=config)
diff --git a/tensorflow/python/estimator/canned/baseline_test.py b/tensorflow/python/estimator/canned/baseline_test.py
new file mode 100644
index 0000000000..96639e88ea
--- /dev/null
+++ b/tensorflow/python/estimator/canned/baseline_test.py
@@ -0,0 +1,1545 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for baseline.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import os
+import shutil
+import tempfile
+
+import numpy as np
+import six
+
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.python.client import session as tf_session
+from tensorflow.python.estimator.canned import baseline
+from tensorflow.python.estimator.canned import metric_keys
+from tensorflow.python.estimator.export import export
+from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.estimator.inputs import pandas_io
+from tensorflow.python.feature_column import feature_column as feature_column_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import checkpoint_utils
+from tensorflow.python.training import input as input_lib
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import queue_runner
+from tensorflow.python.training import saver
+
+
+try:
+ # pylint: disable=g-import-not-at-top
+ import pandas as pd
+ HAS_PANDAS = True
+except IOError:
+ # Pandas writes a temporary file during import. If it fails, don't use pandas.
+ HAS_PANDAS = False
+except ImportError:
+ HAS_PANDAS = False
+
+# pylint rules which are disabled by default for test files.
+# pylint: disable=invalid-name,protected-access,missing-docstring
+
+# Names of variables created by model.
+BIAS_NAME = 'baseline/bias'
+
+
+def assert_close(expected, actual, rtol=1e-04, name='assert_close'):
+ with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope:
+ expected = ops.convert_to_tensor(expected, name='expected')
+ actual = ops.convert_to_tensor(actual, name='actual')
+ rdiff = math_ops.abs(expected - actual, 'diff') / math_ops.abs(expected)
+ rtol = ops.convert_to_tensor(rtol, name='rtol')
+ return check_ops.assert_less(
+ rdiff,
+ rtol,
+ data=('Condition expected =~ actual did not hold element-wise:'
+ 'expected = ', expected, 'actual = ', actual, 'rdiff = ', rdiff,
+ 'rtol = ', rtol,),
+ name=scope)
+
+
+def save_variables_to_ckpt(model_dir):
+ init_all_op = [variables.global_variables_initializer()]
+ with tf_session.Session() as sess:
+ sess.run(init_all_op)
+ saver.Saver().save(sess, os.path.join(model_dir, 'model.ckpt'))
+
+
+def queue_parsed_features(feature_map):
+ tensors_to_enqueue = []
+ keys = []
+ for key, tensor in six.iteritems(feature_map):
+ keys.append(key)
+ tensors_to_enqueue.append(tensor)
+ queue_dtypes = [x.dtype for x in tensors_to_enqueue]
+ input_queue = data_flow_ops.FIFOQueue(capacity=100, dtypes=queue_dtypes)
+ queue_runner.add_queue_runner(
+ queue_runner.QueueRunner(input_queue,
+ [input_queue.enqueue(tensors_to_enqueue)]))
+ dequeued_tensors = input_queue.dequeue()
+ return {keys[i]: dequeued_tensors[i] for i in range(len(dequeued_tensors))}
+
+
+def sorted_key_dict(unsorted_dict):
+ return {k: unsorted_dict[k] for k in sorted(unsorted_dict)}
+
+
+def sigmoid(x):
+ return 1 / (1 + np.exp(-1.0 * x))
+
+
+def _baseline_regressor_fn(*args, **kwargs):
+ return baseline.BaselineRegressor(*args, **kwargs)
+
+
+def _baseline_classifier_fn(*args, **kwargs):
+ return baseline.BaselineClassifier(*args, **kwargs)
+
+
+# Tests for Baseline Regressor.
+
+
+# TODO(b/36813849): Add tests with dynamic shape inputs using placeholders.
+class BaselineRegressorEvaluationTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+ def test_evaluation_for_simple_data(self):
+ with ops.Graph().as_default():
+ variables.Variable([13.0], name=BIAS_NAME)
+ variables.Variable(
+ 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ baseline_regressor = _baseline_regressor_fn(model_dir=self._model_dir)
+ eval_metrics = baseline_regressor.evaluate(
+ input_fn=lambda: ({'age': ((1,),)}, ((10.,),)), steps=1)
+
+ # Logit is bias = 13, while label is 10. Loss is 3**2 = 9.
+ self.assertDictEqual({
+ metric_keys.MetricKeys.LOSS: 9.,
+ metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ ops.GraphKeys.GLOBAL_STEP: 100
+ }, eval_metrics)
+
+ def test_evaluation_batch(self):
+ """Tests evaluation for batch_size==2."""
+ with ops.Graph().as_default():
+ variables.Variable([13.0], name=BIAS_NAME)
+ variables.Variable(
+ 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ baseline_regressor = _baseline_regressor_fn(model_dir=self._model_dir)
+ eval_metrics = baseline_regressor.evaluate(
+ input_fn=lambda: ({'age': ((1,), (1,))}, ((10.,), (10.,))), steps=1)
+
+ # Logit is bias = 13, while label is 10.
+ # Loss per example is 3**2 = 9.
+ # Training loss is the sum over batch = 9 + 9 = 18
+ # Average loss is the average over batch = 9
+ self.assertDictEqual({
+ metric_keys.MetricKeys.LOSS: 18.,
+ metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ ops.GraphKeys.GLOBAL_STEP: 100
+ }, eval_metrics)
+
+ def test_evaluation_weights(self):
+ """Tests evaluation with weights."""
+ with ops.Graph().as_default():
+ variables.Variable([13.0], name=BIAS_NAME)
+ variables.Variable(
+ 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ def _input_fn():
+ features = {'age': ((1,), (1,)), 'weights': ((1.,), (2.,))}
+ labels = ((10.,), (10.,))
+ return features, labels
+
+ baseline_regressor = _baseline_regressor_fn(
+ weight_column='weights',
+ model_dir=self._model_dir)
+ eval_metrics = baseline_regressor.evaluate(input_fn=_input_fn, steps=1)
+
+ # Logit is bias = 13, while label is 10.
+ # Loss per example is 3**2 = 9.
+ # Training loss is the weighted sum over batch = 9 + 2*9 = 27
+ # average loss is the weighted average = 9 + 2*9 / (1 + 2) = 9
+ self.assertDictEqual({
+ metric_keys.MetricKeys.LOSS: 27.,
+ metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ ops.GraphKeys.GLOBAL_STEP: 100
+ }, eval_metrics)
+
+ def test_evaluation_for_multi_dimensions(self):
+ label_dim = 2
+ with ops.Graph().as_default():
+ variables.Variable([46.0, 58.0], name=BIAS_NAME)
+ variables.Variable(100, name='global_step', dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ baseline_regressor = _baseline_regressor_fn(
+ label_dimension=label_dim,
+ model_dir=self._model_dir)
+ input_fn = numpy_io.numpy_input_fn(
+ x={
+ 'age': np.array([[2., 4., 5.]]),
+ },
+ y=np.array([[46., 58.]]),
+ batch_size=1,
+ num_epochs=None,
+ shuffle=False)
+ eval_metrics = baseline_regressor.evaluate(input_fn=input_fn, steps=1)
+
+ self.assertItemsEqual(
+ (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,
+ ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys())
+
+ # Logit is bias which is [46, 58]
+ self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS])
+
+
+class BaselineRegressorPredictTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+ def test_1d(self):
+ """Tests predict when all variables are one-dimensional."""
+ with ops.Graph().as_default():
+ variables.Variable([.2], name=BIAS_NAME)
+ variables.Variable(100, name='global_step', dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ baseline_regressor = _baseline_regressor_fn(model_dir=self._model_dir)
+
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'x': np.array([[2.]])},
+ y=None,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False)
+ predictions = baseline_regressor.predict(input_fn=predict_input_fn)
+ predicted_scores = list([x['predictions'] for x in predictions])
+ # x * weight + bias = 2. * 10. + .2 = 20.2
+ self.assertAllClose([[.2]], predicted_scores)
+
+ def testMultiDim(self):
+ """Tests predict when all variables are multi-dimenstional."""
+ batch_size = 2
+ label_dimension = 3
+ with ops.Graph().as_default():
+ variables.Variable( # shape=[label_dimension]
+ [.2, .4, .6], name=BIAS_NAME)
+ variables.Variable(100, name='global_step', dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ baseline_regressor = _baseline_regressor_fn(
+ label_dimension=label_dimension,
+ model_dir=self._model_dir)
+
+ predict_input_fn = numpy_io.numpy_input_fn(
+ # x shape=[batch_size, x_dim]
+ x={'x': np.array([[1., 2., 3., 4.], [5., 6., 7., 8.]])},
+ y=None,
+ batch_size=batch_size,
+ num_epochs=1,
+ shuffle=False)
+ predictions = baseline_regressor.predict(input_fn=predict_input_fn)
+ predicted_scores = list([x['predictions'] for x in predictions])
+ # score = bias, shape=[batch_size, label_dimension]
+ self.assertAllClose([[0.2, 0.4, 0.6], [0.2, 0.4, 0.6]],
+ predicted_scores)
+
+
+class BaselineRegressorIntegrationTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+ def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+ input_dimension, label_dimension, prediction_length):
+ feature_columns = [
+ feature_column_lib.numeric_column('x', shape=(input_dimension,))
+ ]
+ est = _baseline_regressor_fn(
+ label_dimension=label_dimension,
+ model_dir=self._model_dir)
+
+ # TRAIN
+ # learn y = x
+ est.train(train_input_fn, steps=200)
+
+ # EVALUTE
+ scores = est.evaluate(eval_input_fn)
+ self.assertEqual(200, scores[ops.GraphKeys.GLOBAL_STEP])
+ self.assertIn(metric_keys.MetricKeys.LOSS, six.iterkeys(scores))
+
+ # PREDICT
+ predictions = np.array(
+ [x['predictions'] for x in est.predict(predict_input_fn)])
+ self.assertAllEqual((prediction_length, label_dimension), predictions.shape)
+
+ # EXPORT
+ feature_spec = feature_column_lib.make_parse_example_spec(feature_columns)
+ serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+ feature_spec)
+ export_dir = est.export_savedmodel(tempfile.mkdtemp(),
+ serving_input_receiver_fn)
+ self.assertTrue(gfile.Exists(export_dir))
+
+ def test_numpy_input_fn(self):
+ """Tests complete flow with numpy_input_fn."""
+ label_dimension = 2
+ input_dimension = label_dimension
+ batch_size = 10
+ prediction_length = batch_size
+ data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
+ data = data.reshape(batch_size, label_dimension)
+
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data},
+ y=data,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ eval_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data},
+ y=data,
+ batch_size=batch_size,
+ num_epochs=1,
+ shuffle=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data},
+ y=None,
+ batch_size=batch_size,
+ num_epochs=1,
+ shuffle=False)
+
+ self._test_complete_flow(
+ train_input_fn=train_input_fn,
+ eval_input_fn=eval_input_fn,
+ predict_input_fn=predict_input_fn,
+ input_dimension=input_dimension,
+ label_dimension=label_dimension,
+ prediction_length=prediction_length)
+
+ def test_pandas_input_fn(self):
+ """Tests complete flow with pandas_input_fn."""
+ if not HAS_PANDAS:
+ return
+
+ # Pandas DataFrame natually supports 1 dim data only.
+ label_dimension = 1
+ input_dimension = label_dimension
+ batch_size = 10
+ data = np.array([1., 2., 3., 4.], dtype=np.float32)
+ x = pd.DataFrame({'x': data})
+ y = pd.Series(data)
+ prediction_length = 4
+
+ train_input_fn = pandas_io.pandas_input_fn(
+ x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)
+ eval_input_fn = pandas_io.pandas_input_fn(
+ x=x, y=y, batch_size=batch_size, shuffle=False)
+ predict_input_fn = pandas_io.pandas_input_fn(
+ x=x, batch_size=batch_size, shuffle=False)
+
+ self._test_complete_flow(
+ train_input_fn=train_input_fn,
+ eval_input_fn=eval_input_fn,
+ predict_input_fn=predict_input_fn,
+ input_dimension=input_dimension,
+ label_dimension=label_dimension,
+ prediction_length=prediction_length)
+
+ def test_input_fn_from_parse_example(self):
+ """Tests complete flow with input_fn constructed from parse_example."""
+ label_dimension = 2
+ input_dimension = label_dimension
+ batch_size = 10
+ prediction_length = batch_size
+ data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
+ data = data.reshape(batch_size, label_dimension)
+
+ serialized_examples = []
+ for datum in data:
+ example = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'x':
+ feature_pb2.Feature(float_list=feature_pb2.FloatList(
+ value=datum)),
+ 'y':
+ feature_pb2.Feature(float_list=feature_pb2.FloatList(
+ value=datum[:label_dimension])),
+ }))
+ serialized_examples.append(example.SerializeToString())
+
+ feature_spec = {
+ 'x': parsing_ops.FixedLenFeature([input_dimension], dtypes.float32),
+ 'y': parsing_ops.FixedLenFeature([label_dimension], dtypes.float32),
+ }
+
+ def _train_input_fn():
+ feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)
+ features = queue_parsed_features(feature_map)
+ labels = features.pop('y')
+ return features, labels
+
+ def _eval_input_fn():
+ feature_map = parsing_ops.parse_example(
+ input_lib.limit_epochs(serialized_examples, num_epochs=1),
+ feature_spec)
+ features = queue_parsed_features(feature_map)
+ labels = features.pop('y')
+ return features, labels
+
+ def _predict_input_fn():
+ feature_map = parsing_ops.parse_example(
+ input_lib.limit_epochs(serialized_examples, num_epochs=1),
+ feature_spec)
+ features = queue_parsed_features(feature_map)
+ features.pop('y')
+ return features, None
+
+ self._test_complete_flow(
+ train_input_fn=_train_input_fn,
+ eval_input_fn=_eval_input_fn,
+ predict_input_fn=_predict_input_fn,
+ input_dimension=input_dimension,
+ label_dimension=label_dimension,
+ prediction_length=prediction_length)
+
+
+class BaselineRegressorTrainingTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+ def _mock_optimizer(self, expected_loss=None):
+ expected_var_names = [
+ '%s:0' % BIAS_NAME
+ ]
+
+ def _minimize(loss, global_step=None, var_list=None):
+ trainable_vars = var_list or ops.get_collection(
+ ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertItemsEqual(expected_var_names,
+ [var.name for var in trainable_vars])
+
+ # Verify loss. We can't check the value directly, so we add an assert op.
+ self.assertEquals(0, loss.shape.ndims)
+ if expected_loss is None:
+ if global_step is not None:
+ return state_ops.assign_add(global_step, 1).op
+ return control_flow_ops.no_op()
+ assert_loss = assert_close(
+ math_ops.to_float(expected_loss, name='expected'),
+ loss,
+ name='assert_loss')
+ with ops.control_dependencies((assert_loss,)):
+ if global_step is not None:
+ return state_ops.assign_add(global_step, 1).op
+ return control_flow_ops.no_op()
+
+ mock_optimizer = test.mock.NonCallableMock(
+ spec=optimizer.Optimizer,
+ wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer'))
+ mock_optimizer.minimize = test.mock.MagicMock(wraps=_minimize)
+
+ # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks.
+ # So, return mock_optimizer itself for deepcopy.
+ mock_optimizer.__deepcopy__ = lambda _: mock_optimizer
+ return mock_optimizer
+
+ def _assert_checkpoint(self,
+ label_dimension,
+ expected_global_step,
+ expected_bias=None):
+ shapes = {
+ name: shape
+ for (name, shape) in checkpoint_utils.list_variables(self._model_dir)
+ }
+
+ self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP])
+ self.assertEqual(expected_global_step,
+ checkpoint_utils.load_variable(self._model_dir,
+ ops.GraphKeys.GLOBAL_STEP))
+
+ self.assertEqual([label_dimension], shapes[BIAS_NAME])
+ if expected_bias is not None:
+ self.assertEqual(expected_bias,
+ checkpoint_utils.load_variable(self._model_dir,
+ BIAS_NAME))
+
+ def testFromScratchWithDefaultOptimizer(self):
+ # Create BaselineRegressor.
+ label = 5.
+ age = 17
+ baseline_regressor = _baseline_regressor_fn(model_dir=self._model_dir)
+
+ # Train for a few steps, and validate final checkpoint.
+ num_steps = 10
+ baseline_regressor.train(
+ input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
+ self._assert_checkpoint(label_dimension=1, expected_global_step=num_steps)
+
+ def testTrainWithOneDimLabel(self):
+ label_dimension = 1
+ batch_size = 20
+ est = _baseline_regressor_fn(
+ label_dimension=label_dimension,
+ model_dir=self._model_dir)
+ data_rank_1 = np.linspace(0., 2., batch_size, dtype=np.float32)
+ self.assertEqual((batch_size,), data_rank_1.shape)
+
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'age': data_rank_1},
+ y=data_rank_1,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ est.train(train_input_fn, steps=200)
+ self._assert_checkpoint(label_dimension=1, expected_global_step=200)
+
+ def testTrainWithOneDimWeight(self):
+ label_dimension = 1
+ batch_size = 20
+ est = _baseline_regressor_fn(
+ label_dimension=label_dimension,
+ weight_column='w',
+ model_dir=self._model_dir)
+
+ data_rank_1 = np.linspace(0., 2., batch_size, dtype=np.float32)
+ self.assertEqual((batch_size,), data_rank_1.shape)
+
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'age': data_rank_1,
+ 'w': data_rank_1},
+ y=data_rank_1,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ est.train(train_input_fn, steps=200)
+ self._assert_checkpoint(label_dimension=1, expected_global_step=200)
+
+ def testFromScratch(self):
+ # Create BaselineRegressor.
+ label = 5.
+ age = 17
+ # loss = (logits - label)^2 = (0 - 5.)^2 = 25.
+ mock_optimizer = self._mock_optimizer(expected_loss=25.)
+ baseline_regressor = _baseline_regressor_fn(
+ model_dir=self._model_dir,
+ optimizer=mock_optimizer)
+ self.assertEqual(0, mock_optimizer.minimize.call_count)
+
+ # Train for a few steps, and validate optimizer and final checkpoint.
+ num_steps = 10
+ baseline_regressor.train(
+ input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
+ self.assertEqual(1, mock_optimizer.minimize.call_count)
+ self._assert_checkpoint(
+ label_dimension=1,
+ expected_global_step=num_steps,
+ expected_bias=[0.])
+
+ def testFromCheckpoint(self):
+ # Create initial checkpoint.
+ bias = 7.0
+ initial_global_step = 100
+ with ops.Graph().as_default():
+ variables.Variable([bias], name=BIAS_NAME)
+ variables.Variable(
+ initial_global_step,
+ name=ops.GraphKeys.GLOBAL_STEP,
+ dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ # logits = bias = 6.
+ # loss = (logits - label)^2 = (7 - 5)^2 = 4
+ mock_optimizer = self._mock_optimizer(expected_loss=4.)
+ baseline_regressor = _baseline_regressor_fn(
+ model_dir=self._model_dir,
+ optimizer=mock_optimizer)
+ self.assertEqual(0, mock_optimizer.minimize.call_count)
+
+ # Train for a few steps, and validate optimizer and final checkpoint.
+ num_steps = 10
+ baseline_regressor.train(
+ input_fn=lambda: ({'age': ((17,),)}, ((5.,),)), steps=num_steps)
+ self.assertEqual(1, mock_optimizer.minimize.call_count)
+ self._assert_checkpoint(
+ label_dimension=1,
+ expected_global_step=initial_global_step + num_steps,
+ expected_bias=[bias])
+
+ def testFromCheckpointMultiBatch(self):
+ # Create initial checkpoint.
+ bias = 5.0
+ initial_global_step = 100
+ with ops.Graph().as_default():
+ variables.Variable([bias], name=BIAS_NAME)
+ variables.Variable(
+ initial_global_step,
+ name=ops.GraphKeys.GLOBAL_STEP,
+ dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ # logits = bias
+ # logits[0] = 5.
+ # logits[1] = 5.
+ # loss = sum(logits - label)^2 = (5 - 5)^2 + (5 - 3)^2 = 4
+ mock_optimizer = self._mock_optimizer(expected_loss=4.)
+ baseline_regressor = _baseline_regressor_fn(
+ model_dir=self._model_dir,
+ optimizer=mock_optimizer)
+ self.assertEqual(0, mock_optimizer.minimize.call_count)
+
+ # Train for a few steps, and validate optimizer and final checkpoint.
+ num_steps = 10
+ baseline_regressor.train(
+ input_fn=lambda: ({'age': ((17,), (15,))}, ((5.,), (3.,))),
+ steps=num_steps)
+ self.assertEqual(1, mock_optimizer.minimize.call_count)
+ self._assert_checkpoint(
+ label_dimension=1,
+ expected_global_step=initial_global_step + num_steps,
+ expected_bias=bias)
+
+
+# Tests for Baseline Classifier.
+
+
+class BaselineClassifierTrainingTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ shutil.rmtree(self._model_dir)
+
+ def _mock_optimizer(self, expected_loss=None):
+ expected_var_names = [
+ '%s:0' % BIAS_NAME
+ ]
+
+ def _minimize(loss, global_step):
+ trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertItemsEqual(
+ expected_var_names,
+ [var.name for var in trainable_vars])
+
+ # Verify loss. We can't check the value directly, so we add an assert op.
+ self.assertEquals(0, loss.shape.ndims)
+ if expected_loss is None:
+ return state_ops.assign_add(global_step, 1).op
+ assert_loss = assert_close(
+ math_ops.to_float(expected_loss, name='expected'),
+ loss,
+ name='assert_loss')
+ with ops.control_dependencies((assert_loss,)):
+ return state_ops.assign_add(global_step, 1).op
+
+ mock_optimizer = test.mock.NonCallableMock(
+ spec=optimizer.Optimizer,
+ wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer'))
+ mock_optimizer.minimize = test.mock.MagicMock(wraps=_minimize)
+
+ # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks.
+ # So, return mock_optimizer itself for deepcopy.
+ mock_optimizer.__deepcopy__ = lambda _: mock_optimizer
+ return mock_optimizer
+
+ def _assert_checkpoint(
+ self, n_classes, expected_global_step, expected_bias=None):
+ logits_dimension = n_classes if n_classes > 2 else 1
+
+ shapes = {
+ name: shape for (name, shape) in
+ checkpoint_utils.list_variables(self._model_dir)
+ }
+
+ self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP])
+ self.assertEqual(
+ expected_global_step,
+ checkpoint_utils.load_variable(
+ self._model_dir, ops.GraphKeys.GLOBAL_STEP))
+
+ self.assertEqual([logits_dimension], shapes[BIAS_NAME])
+ if expected_bias is not None:
+ self.assertAllEqual(expected_bias,
+ checkpoint_utils.load_variable(
+ self._model_dir, BIAS_NAME))
+
+ def _testFromScratchWithDefaultOptimizer(self, n_classes):
+ label = 0
+ age = 17
+ est = baseline.BaselineClassifier(
+ n_classes=n_classes,
+ model_dir=self._model_dir)
+
+ # Train for a few steps, and validate final checkpoint.
+ num_steps = 10
+ est.train(
+ input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
+ self._assert_checkpoint(n_classes, num_steps)
+
+ def testBinaryClassesFromScratchWithDefaultOptimizer(self):
+ self._testFromScratchWithDefaultOptimizer(n_classes=2)
+
+ def testMultiClassesFromScratchWithDefaultOptimizer(self):
+ self._testFromScratchWithDefaultOptimizer(n_classes=4)
+
+ def _testTrainWithTwoDimsLabel(self, n_classes):
+ batch_size = 20
+
+ est = baseline.BaselineClassifier(
+ n_classes=n_classes,
+ model_dir=self._model_dir)
+ data_rank_1 = np.array([0, 1])
+ data_rank_2 = np.array([[0], [1]])
+ self.assertEqual((2,), data_rank_1.shape)
+ self.assertEqual((2, 1), data_rank_2.shape)
+
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'age': data_rank_1},
+ y=data_rank_2,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ est.train(train_input_fn, steps=200)
+ self._assert_checkpoint(n_classes, 200)
+
+ def testBinaryClassesTrainWithTwoDimsLabel(self):
+ self._testTrainWithTwoDimsLabel(n_classes=2)
+
+ def testMultiClassesTrainWithTwoDimsLabel(self):
+ self._testTrainWithTwoDimsLabel(n_classes=4)
+
+ def _testTrainWithOneDimLabel(self, n_classes):
+ batch_size = 20
+
+ est = baseline.BaselineClassifier(
+ n_classes=n_classes,
+ model_dir=self._model_dir)
+ data_rank_1 = np.array([0, 1])
+ self.assertEqual((2,), data_rank_1.shape)
+
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'age': data_rank_1},
+ y=data_rank_1,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ est.train(train_input_fn, steps=200)
+ self._assert_checkpoint(n_classes, 200)
+
+ def testBinaryClassesTrainWithOneDimLabel(self):
+ self._testTrainWithOneDimLabel(n_classes=2)
+
+ def testMultiClassesTrainWithOneDimLabel(self):
+ self._testTrainWithOneDimLabel(n_classes=4)
+
+ def _testTrainWithTwoDimsWeight(self, n_classes):
+ batch_size = 20
+
+ est = baseline.BaselineClassifier(
+ weight_column='w',
+ n_classes=n_classes,
+ model_dir=self._model_dir)
+ data_rank_1 = np.array([0, 1])
+ data_rank_2 = np.array([[0], [1]])
+ self.assertEqual((2,), data_rank_1.shape)
+ self.assertEqual((2, 1), data_rank_2.shape)
+
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'age': data_rank_1, 'w': data_rank_2}, y=data_rank_1,
+ batch_size=batch_size, num_epochs=None,
+ shuffle=True)
+ est.train(train_input_fn, steps=200)
+ self._assert_checkpoint(n_classes, 200)
+
+ def testBinaryClassesTrainWithTwoDimsWeight(self):
+ self._testTrainWithTwoDimsWeight(n_classes=2)
+
+ def testMultiClassesTrainWithTwoDimsWeight(self):
+ self._testTrainWithTwoDimsWeight(n_classes=4)
+
+ def _testTrainWithOneDimWeight(self, n_classes):
+ batch_size = 20
+
+ est = baseline.BaselineClassifier(
+ weight_column='w',
+ n_classes=n_classes,
+ model_dir=self._model_dir)
+ data_rank_1 = np.array([0, 1])
+ self.assertEqual((2,), data_rank_1.shape)
+
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'age': data_rank_1, 'w': data_rank_1}, y=data_rank_1,
+ batch_size=batch_size, num_epochs=None,
+ shuffle=True)
+ est.train(train_input_fn, steps=200)
+ self._assert_checkpoint(n_classes, 200)
+
+ def testBinaryClassesTrainWithOneDimWeight(self):
+ self._testTrainWithOneDimWeight(n_classes=2)
+
+ def testMultiClassesTrainWithOneDimWeight(self):
+ self._testTrainWithOneDimWeight(n_classes=4)
+
+ def _testFromScratch(self, n_classes):
+ label = 1
+ age = 17
+ # For binary classifier:
+ # loss = sigmoid_cross_entropy(logits, label) where logits=0 (weights are
+ # all zero initially) and label = 1 so,
+ # loss = 1 * -log ( sigmoid(logits) ) = 0.69315
+ # For multi class classifier:
+ # loss = cross_entropy(logits, label) where logits are all 0s (weights are
+ # all zero initially) and label = 1 so,
+ # loss = 1 * -log ( 1.0 / n_classes )
+ # For this particular test case, as logits are same, the formula
+ # 1 * -log ( 1.0 / n_classes ) covers both binary and multi class cases.
+ mock_optimizer = self._mock_optimizer(
+ expected_loss=-1 * math.log(1.0/n_classes))
+
+ est = baseline.BaselineClassifier(
+ n_classes=n_classes,
+ optimizer=mock_optimizer,
+ model_dir=self._model_dir)
+ self.assertEqual(0, mock_optimizer.minimize.call_count)
+
+ # Train for a few steps, and validate optimizer and final checkpoint.
+ num_steps = 10
+ est.train(
+ input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
+ self.assertEqual(1, mock_optimizer.minimize.call_count)
+ self._assert_checkpoint(
+ n_classes,
+ expected_global_step=num_steps,
+ expected_bias=[0.] if n_classes == 2 else [.0] * n_classes)
+
+ def testBinaryClassesFromScratch(self):
+ self._testFromScratch(n_classes=2)
+
+ def testMultiClassesFromScratch(self):
+ self._testFromScratch(n_classes=4)
+
+ def _testFromCheckpoint(self, n_classes):
+ # Create initial checkpoint.
+ label = 1
+ age = 17
+ bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes
+ initial_global_step = 100
+ with ops.Graph().as_default():
+ variables.Variable(bias, name=BIAS_NAME)
+ variables.Variable(
+ initial_global_step, name=ops.GraphKeys.GLOBAL_STEP,
+ dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ # For binary classifier:
+ # logits = bias = -1.
+ # loss = sigmoid_cross_entropy(logits, label)
+ # so, loss = 1 * -log ( sigmoid(-1) ) = 1.3133
+ # For multi class classifier:
+ # loss = cross_entropy(logits, label)
+ # where logits = bias and label = 1
+ # so, loss = 1 * -log ( softmax(logits)[1] )
+ if n_classes == 2:
+ expected_loss = 1.3133
+ else:
+ logits = bias
+ logits_exp = np.exp(logits)
+ softmax = logits_exp / logits_exp.sum()
+ expected_loss = -1 * math.log(softmax[label])
+
+ mock_optimizer = self._mock_optimizer(expected_loss=expected_loss)
+
+ est = baseline.BaselineClassifier(
+ n_classes=n_classes,
+ optimizer=mock_optimizer,
+ model_dir=self._model_dir)
+ self.assertEqual(0, mock_optimizer.minimize.call_count)
+
+ # Train for a few steps, and validate optimizer and final checkpoint.
+ num_steps = 10
+ est.train(
+ input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
+ self.assertEqual(1, mock_optimizer.minimize.call_count)
+ self._assert_checkpoint(
+ n_classes,
+ expected_global_step=initial_global_step + num_steps,
+ expected_bias=bias)
+
+ def testBinaryClassesFromCheckpoint(self):
+ self._testFromCheckpoint(n_classes=2)
+
+ def testMultiClassesFromCheckpoint(self):
+ self._testFromCheckpoint(n_classes=4)
+
+ def _testFromCheckpointFloatLabels(self, n_classes):
+ """Tests float labels for binary classification."""
+ # Create initial checkpoint.
+ if n_classes > 2:
+ return
+ label = 0.8
+ age = 17
+ bias = [-1.0]
+ initial_global_step = 100
+ with ops.Graph().as_default():
+ variables.Variable(bias, name=BIAS_NAME)
+ variables.Variable(
+ initial_global_step, name=ops.GraphKeys.GLOBAL_STEP,
+ dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ # logits = bias = -1.
+ # loss = sigmoid_cross_entropy(logits, label)
+ # => loss = -0.8 * log(sigmoid(-1)) -0.2 * log(sigmoid(+1)) = 1.1132617
+ mock_optimizer = self._mock_optimizer(expected_loss=1.1132617)
+
+ est = baseline.BaselineClassifier(
+ n_classes=n_classes,
+ optimizer=mock_optimizer,
+ model_dir=self._model_dir)
+ self.assertEqual(0, mock_optimizer.minimize.call_count)
+
+ # Train for a few steps, and validate optimizer and final checkpoint.
+ num_steps = 10
+ est.train(
+ input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
+ self.assertEqual(1, mock_optimizer.minimize.call_count)
+
+ def testBinaryClassesFromCheckpointFloatLabels(self):
+ self._testFromCheckpointFloatLabels(n_classes=2)
+
+ def testMultiClassesFromCheckpointFloatLabels(self):
+ self._testFromCheckpointFloatLabels(n_classes=4)
+
+ def _testFromCheckpointMultiBatch(self, n_classes):
+ # Create initial checkpoint.
+ label = [1, 0]
+ age = [17, 18.5]
+ # For binary case, the expected weight has shape (1,1). For multi class
+ # case, the shape is (1, n_classes). In order to test the weights, set
+ # weights as 2.0 * range(n_classes).
+ bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes
+ initial_global_step = 100
+ with ops.Graph().as_default():
+ variables.Variable(bias, name=BIAS_NAME)
+ variables.Variable(
+ initial_global_step, name=ops.GraphKeys.GLOBAL_STEP,
+ dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ # For binary classifier:
+ # logits = bias
+ # logits[0] = -1.
+ # logits[1] = -1.
+ # loss = sigmoid_cross_entropy(logits, label)
+ # so, loss[0] = 1 * -log ( sigmoid(-1) ) = 1.3133
+ # loss[1] = (1 - 0) * -log ( 1- sigmoid(-1) ) = 0.3132
+ # For multi class classifier:
+ # loss = cross_entropy(logits, label)
+ # where logits = bias and label = [1, 0]
+ # so, loss = 1 * -log ( softmax(logits)[label] )
+ if n_classes == 2:
+ expected_loss = (1.3133 + 0.3132)
+ else:
+ # Expand logits since batch_size=2
+ logits = bias * np.ones(shape=(2, 1))
+ logits_exp = np.exp(logits)
+ softmax_row_0 = logits_exp[0] / logits_exp[0].sum()
+ softmax_row_1 = logits_exp[1] / logits_exp[1].sum()
+ expected_loss_0 = -1 * math.log(softmax_row_0[label[0]])
+ expected_loss_1 = -1 * math.log(softmax_row_1[label[1]])
+ expected_loss = expected_loss_0 + expected_loss_1
+
+ mock_optimizer = self._mock_optimizer(expected_loss=expected_loss)
+
+ est = baseline.BaselineClassifier(
+ n_classes=n_classes,
+ optimizer=mock_optimizer,
+ model_dir=self._model_dir)
+ self.assertEqual(0, mock_optimizer.minimize.call_count)
+
+ # Train for a few steps, and validate optimizer and final checkpoint.
+ num_steps = 10
+ est.train(
+ input_fn=lambda: ({'age': (age)}, (label)),
+ steps=num_steps)
+ self.assertEqual(1, mock_optimizer.minimize.call_count)
+ self._assert_checkpoint(
+ n_classes,
+ expected_global_step=initial_global_step + num_steps,
+ expected_bias=bias)
+
+ def testBinaryClassesFromCheckpointMultiBatch(self):
+ self._testFromCheckpointMultiBatch(n_classes=2)
+
+ def testMultiClassesFromCheckpointMultiBatch(self):
+ self._testFromCheckpointMultiBatch(n_classes=4)
+
+
+class BaselineClassifierEvaluationTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ shutil.rmtree(self._model_dir)
+
+ def _test_evaluation_for_simple_data(self, n_classes):
+ label = 1
+ age = 1.
+
+ bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes
+
+ with ops.Graph().as_default():
+ variables.Variable(bias, name=BIAS_NAME)
+ variables.Variable(
+ 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ est = _baseline_classifier_fn(
+ n_classes=n_classes,
+ model_dir=self._model_dir)
+ eval_metrics = est.evaluate(
+ input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=1)
+
+ if n_classes == 2:
+ # Binary classes: loss = -log(sigmoid(-1)) = 1.3133
+ # Prediction = sigmoid(-1) = 0.2689
+ expected_metrics = {
+ metric_keys.MetricKeys.LOSS: 1.3133,
+ ops.GraphKeys.GLOBAL_STEP: 100,
+ metric_keys.MetricKeys.LOSS_MEAN: 1.3133,
+ metric_keys.MetricKeys.ACCURACY: 0.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 0.2689,
+ metric_keys.MetricKeys.LABEL_MEAN: 1.,
+ metric_keys.MetricKeys.ACCURACY_BASELINE: 1,
+ metric_keys.MetricKeys.AUC: 0.,
+ metric_keys.MetricKeys.AUC_PR: 1.,
+ }
+ else:
+ # Multi classes: loss = 1 * -log ( softmax(logits)[label] )
+ logits = bias
+ logits_exp = np.exp(logits)
+ softmax = logits_exp / logits_exp.sum()
+ expected_loss = -1 * math.log(softmax[label])
+
+ expected_metrics = {
+ metric_keys.MetricKeys.LOSS: expected_loss,
+ ops.GraphKeys.GLOBAL_STEP: 100,
+ metric_keys.MetricKeys.LOSS_MEAN: expected_loss,
+ metric_keys.MetricKeys.ACCURACY: 0.,
+ }
+
+ self.assertAllClose(sorted_key_dict(expected_metrics),
+ sorted_key_dict(eval_metrics), rtol=1e-3)
+
+ def test_binary_classes_evaluation_for_simple_data(self):
+ self._test_evaluation_for_simple_data(n_classes=2)
+
+ def test_multi_classes_evaluation_for_simple_data(self):
+ self._test_evaluation_for_simple_data(n_classes=4)
+
+ def _test_evaluation_batch(self, n_classes):
+ """Tests evaluation for batch_size==2."""
+ label = [1, 0]
+ age = [17., 18.]
+ bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes
+ initial_global_step = 100
+ with ops.Graph().as_default():
+ variables.Variable(bias, name=BIAS_NAME)
+ variables.Variable(
+ initial_global_step, name=ops.GraphKeys.GLOBAL_STEP,
+ dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ est = _baseline_classifier_fn(
+ n_classes=n_classes,
+ model_dir=self._model_dir)
+ eval_metrics = est.evaluate(
+ input_fn=lambda: ({'age': (age)}, (label)), steps=1)
+
+ if n_classes == 2:
+ # Logits are (-1., -1.) labels are (1, 0).
+ # Loss is
+ # loss for row 1: 1 * -log(sigmoid(-1)) = 1.3133
+ # loss for row 2: (1 - 0) * -log(1 - sigmoid(-1)) = 0.3132
+ # Prediction = sigmoid(-1) = 0.2689
+ expected_loss = 1.3133 + 0.3132
+
+ expected_metrics = {
+ metric_keys.MetricKeys.LOSS: expected_loss,
+ ops.GraphKeys.GLOBAL_STEP: 100,
+ metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,
+ metric_keys.MetricKeys.ACCURACY: 0.5,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 0.2689,
+ metric_keys.MetricKeys.LABEL_MEAN: 0.5,
+ metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5,
+ metric_keys.MetricKeys.AUC: 0.5,
+ metric_keys.MetricKeys.AUC_PR: 0.75,
+ }
+ else:
+ # Expand logits since batch_size=2
+ logits = bias * np.ones(shape=(2, 1))
+ logits_exp = np.exp(logits)
+ softmax_row_0 = logits_exp[0] / logits_exp[0].sum()
+ softmax_row_1 = logits_exp[1] / logits_exp[1].sum()
+ expected_loss_0 = -1 * math.log(softmax_row_0[label[0]])
+ expected_loss_1 = -1 * math.log(softmax_row_1[label[1]])
+ expected_loss = expected_loss_0 + expected_loss_1
+
+ expected_metrics = {
+ metric_keys.MetricKeys.LOSS: expected_loss,
+ ops.GraphKeys.GLOBAL_STEP: 100,
+ metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,
+ metric_keys.MetricKeys.ACCURACY: 0.5,
+ }
+
+ self.assertAllClose(sorted_key_dict(expected_metrics),
+ sorted_key_dict(eval_metrics), rtol=1e-3)
+
+ def test_binary_classes_evaluation_batch(self):
+ self._test_evaluation_batch(n_classes=2)
+
+ def test_multi_classes_evaluation_batch(self):
+ self._test_evaluation_batch(n_classes=4)
+
+ def _test_evaluation_weights(self, n_classes):
+ """Tests evaluation with weights."""
+
+ label = [1, 0]
+ age = [17., 18.]
+ weights = [1., 2.]
+ # For binary case, the expected weight has shape (1,1). For multi class
+ # case, the shape is (1, n_classes). In order to test the weights, set
+ # weights as 2.0 * range(n_classes).
+ bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes
+ initial_global_step = 100
+ with ops.Graph().as_default():
+ variables.Variable(bias, name=BIAS_NAME)
+ variables.Variable(
+ initial_global_step, name=ops.GraphKeys.GLOBAL_STEP,
+ dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ est = _baseline_classifier_fn(
+ n_classes=n_classes,
+ weight_column='w',
+ model_dir=self._model_dir)
+ eval_metrics = est.evaluate(
+ input_fn=lambda: ({'age': (age), 'w': (weights)}, (label)), steps=1)
+
+ if n_classes == 2:
+ # Logits are (-1., -1.) labels are (1, 0).
+ # Loss is
+ # loss for row 1: 1 * -log(sigmoid(-1)) = 1.3133
+ # loss for row 2: (1 - 0) * -log(1 - sigmoid(-1)) = 0.3132
+ # weights = [1., 2.]
+ expected_loss = 1.3133 * 1. + 0.3132 * 2.
+ loss_mean = expected_loss / (1.0 + 2.0)
+ label_mean = np.average(label, weights=weights)
+ logits = [-1, -1]
+ logistics = sigmoid(np.array(logits))
+ predictions_mean = np.average(logistics, weights=weights)
+
+ expected_metrics = {
+ metric_keys.MetricKeys.LOSS: expected_loss,
+ ops.GraphKeys.GLOBAL_STEP: 100,
+ metric_keys.MetricKeys.LOSS_MEAN: loss_mean,
+ metric_keys.MetricKeys.ACCURACY: 2. / (1. + 2.),
+ metric_keys.MetricKeys.PREDICTION_MEAN: predictions_mean,
+ metric_keys.MetricKeys.LABEL_MEAN: label_mean,
+ metric_keys.MetricKeys.ACCURACY_BASELINE: (
+ max(label_mean, 1-label_mean)),
+ metric_keys.MetricKeys.AUC: 0.5,
+ metric_keys.MetricKeys.AUC_PR: 2. / (1. + 2.),
+ }
+ else:
+ # Multi classes: unweighted_loss = 1 * -log ( soft_max(logits)[label] )
+ # Expand logits since batch_size=2
+ logits = bias * np.ones(shape=(2, 1))
+ logits_exp = np.exp(logits)
+ softmax_row_0 = logits_exp[0] / logits_exp[0].sum()
+ softmax_row_1 = logits_exp[1] / logits_exp[1].sum()
+ expected_loss_0 = -1 * math.log(softmax_row_0[label[0]])
+ expected_loss_1 = -1 * math.log(softmax_row_1[label[1]])
+ loss_mean = np.average([expected_loss_0, expected_loss_1],
+ weights=weights)
+ expected_loss = loss_mean * np.sum(weights)
+
+ expected_metrics = {
+ metric_keys.MetricKeys.LOSS: expected_loss,
+ ops.GraphKeys.GLOBAL_STEP: 100,
+ metric_keys.MetricKeys.LOSS_MEAN: loss_mean,
+ metric_keys.MetricKeys.ACCURACY: 2. / (1. + 2.),
+ }
+
+ self.assertAllClose(sorted_key_dict(expected_metrics),
+ sorted_key_dict(eval_metrics), rtol=1e-3)
+
+ def test_binary_classes_evaluation_weights(self):
+ self._test_evaluation_weights(n_classes=2)
+
+ def test_multi_classes_evaluation_weights(self):
+ self._test_evaluation_weights(n_classes=4)
+
+
+class BaselineClassifierPredictTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ shutil.rmtree(self._model_dir)
+
+ def _testPredictions(self, n_classes, label_vocabulary, label_output_fn):
+ """Tests predict when all variables are one-dimensional."""
+ age = 1.
+
+ bias = [10.0] if n_classes == 2 else [10.0] * n_classes
+
+ with ops.Graph().as_default():
+ variables.Variable(bias, name=BIAS_NAME)
+ variables.Variable(100, name='global_step', dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ est = _baseline_classifier_fn(
+ label_vocabulary=label_vocabulary,
+ n_classes=n_classes,
+ model_dir=self._model_dir)
+
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'age': np.array([[age]])},
+ y=None,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False)
+ predictions = list(est.predict(input_fn=predict_input_fn))
+
+ if n_classes == 2:
+ scalar_logits = bias[0]
+ two_classes_logits = [0, scalar_logits]
+ two_classes_logits_exp = np.exp(two_classes_logits)
+ softmax = two_classes_logits_exp / two_classes_logits_exp.sum()
+
+ expected_predictions = {
+ 'class_ids': [1],
+ 'classes': [label_output_fn(1)],
+ 'logistic': [sigmoid(np.array(scalar_logits))],
+ 'logits': [scalar_logits],
+ 'probabilities': softmax,
+ }
+ else:
+ onedim_logits = np.array(bias)
+ class_ids = onedim_logits.argmax()
+ logits_exp = np.exp(onedim_logits)
+ softmax = logits_exp / logits_exp.sum()
+ expected_predictions = {
+ 'class_ids': [class_ids],
+ 'classes': [label_output_fn(class_ids)],
+ 'logits': onedim_logits,
+ 'probabilities': softmax,
+ }
+
+ self.assertEqual(1, len(predictions))
+ # assertAllClose cannot handle byte type.
+ self.assertEqual(expected_predictions['classes'], predictions[0]['classes'])
+ expected_predictions.pop('classes')
+ predictions[0].pop('classes')
+ self.assertAllClose(sorted_key_dict(expected_predictions),
+ sorted_key_dict(predictions[0]))
+
+ def testBinaryClassesWithoutLabelVocabulary(self):
+ n_classes = 2
+ self._testPredictions(n_classes,
+ label_vocabulary=None,
+ label_output_fn=lambda x: ('%s' % x).encode())
+
+ def testBinaryClassesWithLabelVocabulary(self):
+ n_classes = 2
+ self._testPredictions(
+ n_classes,
+ label_vocabulary=['class_vocab_{}'.format(i)
+ for i in range(n_classes)],
+ label_output_fn=lambda x: ('class_vocab_%s' % x).encode())
+
+ def testMultiClassesWithoutLabelVocabulary(self):
+ n_classes = 4
+ self._testPredictions(
+ n_classes,
+ label_vocabulary=None,
+ label_output_fn=lambda x: ('%s' % x).encode())
+
+ def testMultiClassesWithLabelVocabulary(self):
+ n_classes = 4
+ self._testPredictions(
+ n_classes,
+ label_vocabulary=['class_vocab_{}'.format(i)
+ for i in range(n_classes)],
+ label_output_fn=lambda x: ('class_vocab_%s' % x).encode())
+
+
+class BaselineClassifierIntegrationTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ shutil.rmtree(self._model_dir)
+
+ def _test_complete_flow(self, n_classes, train_input_fn, eval_input_fn,
+ predict_input_fn, input_dimension, prediction_length):
+ feature_columns = [
+ feature_column_lib.numeric_column('x', shape=(input_dimension,))
+ ]
+ est = _baseline_classifier_fn(
+ n_classes=n_classes,
+ model_dir=self._model_dir)
+
+ # TRAIN
+ # learn y = x
+ est.train(train_input_fn, steps=200)
+
+ # EVALUTE
+ scores = est.evaluate(eval_input_fn)
+ self.assertEqual(200, scores[ops.GraphKeys.GLOBAL_STEP])
+ self.assertIn(metric_keys.MetricKeys.LOSS, six.iterkeys(scores))
+
+ # PREDICT
+ predictions = np.array(
+ [x['classes'] for x in est.predict(predict_input_fn)])
+ self.assertAllEqual((prediction_length, 1), predictions.shape)
+
+ # EXPORT
+ feature_spec = feature_column_lib.make_parse_example_spec(feature_columns)
+ serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+ feature_spec)
+ export_dir = est.export_savedmodel(tempfile.mkdtemp(),
+ serving_input_receiver_fn)
+ self.assertTrue(gfile.Exists(export_dir))
+
+ def _test_numpy_input_fn(self, n_classes):
+ """Tests complete flow with numpy_input_fn."""
+ input_dimension = 4
+ batch_size = 10
+ prediction_length = batch_size
+ data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)
+ data = data.reshape(batch_size, input_dimension)
+ target = np.array([1] * batch_size)
+
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data},
+ y=target,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ eval_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data},
+ y=target,
+ batch_size=batch_size,
+ num_epochs=1,
+ shuffle=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data},
+ y=None,
+ batch_size=batch_size,
+ num_epochs=1,
+ shuffle=False)
+
+ self._test_complete_flow(
+ n_classes=n_classes,
+ train_input_fn=train_input_fn,
+ eval_input_fn=eval_input_fn,
+ predict_input_fn=predict_input_fn,
+ input_dimension=input_dimension,
+ prediction_length=prediction_length)
+
+ def test_binary_classes_numpy_input_fn(self):
+ self._test_numpy_input_fn(n_classes=2)
+
+ def test_multi_classes_numpy_input_fn(self):
+ self._test_numpy_input_fn(n_classes=4)
+
+ def _test_pandas_input_fn(self, n_classes):
+ """Tests complete flow with pandas_input_fn."""
+ if not HAS_PANDAS:
+ return
+
+ # Pandas DataFrame natually supports 1 dim data only.
+ input_dimension = 1
+ batch_size = 10
+ data = np.array([1., 2., 3., 4.], dtype=np.float32)
+ target = np.array([1, 0, 1, 0], dtype=np.int32)
+ x = pd.DataFrame({'x': data})
+ y = pd.Series(target)
+ prediction_length = 4
+
+ train_input_fn = pandas_io.pandas_input_fn(
+ x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)
+ eval_input_fn = pandas_io.pandas_input_fn(
+ x=x, y=y, batch_size=batch_size, shuffle=False)
+ predict_input_fn = pandas_io.pandas_input_fn(
+ x=x, batch_size=batch_size, shuffle=False)
+
+ self._test_complete_flow(
+ n_classes=n_classes,
+ train_input_fn=train_input_fn,
+ eval_input_fn=eval_input_fn,
+ predict_input_fn=predict_input_fn,
+ input_dimension=input_dimension,
+ prediction_length=prediction_length)
+
+ def test_binary_classes_pandas_input_fn(self):
+ self._test_pandas_input_fn(n_classes=2)
+
+ def test_multi_classes_pandas_input_fn(self):
+ self._test_pandas_input_fn(n_classes=4)
+
+ def _test_input_fn_from_parse_example(self, n_classes):
+ """Tests complete flow with input_fn constructed from parse_example."""
+ input_dimension = 2
+ batch_size = 10
+ prediction_length = batch_size
+ data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)
+ data = data.reshape(batch_size, input_dimension)
+ target = np.array([1] * batch_size, dtype=np.int64)
+
+ serialized_examples = []
+ for x, y in zip(data, target):
+ example = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'x':
+ feature_pb2.Feature(float_list=feature_pb2.FloatList(
+ value=x)),
+ 'y':
+ feature_pb2.Feature(int64_list=feature_pb2.Int64List(
+ value=[y])),
+ }))
+ serialized_examples.append(example.SerializeToString())
+
+ feature_spec = {
+ 'x': parsing_ops.FixedLenFeature([input_dimension], dtypes.float32),
+ 'y': parsing_ops.FixedLenFeature([1], dtypes.int64),
+ }
+
+ def _train_input_fn():
+ feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)
+ features = queue_parsed_features(feature_map)
+ labels = features.pop('y')
+ return features, labels
+
+ def _eval_input_fn():
+ feature_map = parsing_ops.parse_example(
+ input_lib.limit_epochs(serialized_examples, num_epochs=1),
+ feature_spec)
+ features = queue_parsed_features(feature_map)
+ labels = features.pop('y')
+ return features, labels
+
+ def _predict_input_fn():
+ feature_map = parsing_ops.parse_example(
+ input_lib.limit_epochs(serialized_examples, num_epochs=1),
+ feature_spec)
+ features = queue_parsed_features(feature_map)
+ features.pop('y')
+ return features, None
+
+ self._test_complete_flow(
+ n_classes=n_classes,
+ train_input_fn=_train_input_fn,
+ eval_input_fn=_eval_input_fn,
+ predict_input_fn=_predict_input_fn,
+ input_dimension=input_dimension,
+ prediction_length=prediction_length)
+
+ def test_binary_classes_input_fn_from_parse_example(self):
+ self._test_input_fn_from_parse_example(n_classes=2)
+
+ def test_multi_classes_input_fn_from_parse_example(self):
+ self._test_input_fn_from_parse_example(n_classes=4)
+
+
+# Tests for Baseline logit_fn.
+
+
+class BaselineLogitFnTest(test.TestCase):
+
+ def test_basic_logit_correctness(self):
+ """baseline_logit_fn simply returns the bias variable."""
+ with ops.Graph().as_default():
+ logit_fn = baseline._baseline_logit_fn_builder(num_outputs=2)
+ logits = logit_fn(features={'age': [[23.], [31.]]})
+ with variable_scope.variable_scope('baseline', reuse=True):
+ bias_var = variable_scope.get_variable('bias')
+ with tf_session.Session() as sess:
+ sess.run([variables.global_variables_initializer()])
+ self.assertAllClose([[0., 0.], [0., 0.]], logits.eval())
+ sess.run(bias_var.assign([10., 5.]))
+ self.assertAllClose([[10., 5.], [10., 5.]], logits.eval())
+
+
+if __name__ == '__main__':
+ test.main()
+
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index 01c00621ce..d13ecd13a1 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -264,26 +264,55 @@ def _check_dense_labels_match_logits_and_reshape(
return array_ops.identity(labels, name=scope)
-def _check_weights_match_logits_and_reshape(weights, logits):
- """Checks that weights shape matches logits and reshapes if needed.
+def _get_weights_and_check_match_logits(
+ features, weight_column, logits, allow_per_logit_weights=False):
+ """Fetches weights from features and checks that the shape matches logits.
Consider logits of shape [D0, D1, ... DN, logits_dimension]. Weights shape
can be either:
- * [D0, D1, ... DN, logits_dimension]
+ * [D0, D1, ... DN, logits_dimension] if `allow_per_logit_weights=True`.
* [D0, D1, ... DN, 1]
* [D0, D1, ... DN]: In this case, weights is reshaped into
[D0, D1, ... DN, 1] to work with weight broadcasting rules.
Args:
- weights: weights Tensor.
+ features: The features dict that contains weights.
+ weight_column: The weight column. If not given, this method returns 1.
logits: logits Tensor.
+ allow_per_logit_weights: Boolean. Whether we allow weights along the logits
+ dimension, namely shape `[D0, D1, ... DN, logits_dimension]`.
Returns:
Validated and reshaped weights Tensor.
+ Raises:
+ ValueError: If the weights `Tensor` cannot be cast into float.
"""
- err_msg = (
- 'weights shape must be [D0, D1, ... DN], [D0, D1, ... DN, 1] or '
- '[D0, D1, ... DN, logits_dimension]')
- with ops.name_scope(None, 'weights', (weights, logits)) as scope:
+ if allow_per_logit_weights:
+ err_msg = (
+ 'weights shape must be [D0, D1, ... DN], [D0, D1, ... DN, 1] or '
+ '[D0, D1, ... DN, logits_dimension]')
+ else:
+ err_msg = (
+ 'weights shape must be [D0, D1, ... DN] or [D0, D1, ... DN, 1]')
+ with ops.name_scope(
+ None, 'weights',
+ values=tuple(six.itervalues(features)) + (logits,)) as scope:
+ # Fetch the weights.
+ if weight_column is None:
+ return 1.
+ if isinstance(weight_column, six.string_types):
+ weight_column = feature_column_lib.numeric_column(
+ key=weight_column, shape=(1,))
+ if not isinstance(weight_column, feature_column_lib._NumericColumn): # pylint: disable=protected-access
+ raise TypeError('Weight column must be either a string or _NumericColumn.'
+ ' Given type: {}.'.format(type(weight_column)))
+ weights = weight_column._get_dense_tensor( # pylint: disable=protected-access
+ feature_column_lib._LazyBuilder(features)) # pylint: disable=protected-access
+ if not (weights.dtype.is_floating or weights.dtype.is_integer):
+ raise ValueError('Weight column should be castable to float. '
+ 'Given dtype: {}'.format(weights.dtype))
+ weights = math_ops.to_float(weights, name='weights')
+
+ # Validate the weights shape.
weights_shape = array_ops.shape(weights, name='weights_shape')
logits_shape = array_ops.shape(logits, name='logits_shape')
if (weights.shape.ndims is not None and logits.shape.ndims is not None and
@@ -295,42 +324,24 @@ def _check_weights_match_logits_and_reshape(weights, logits):
with ops.control_dependencies([assert_dimension]):
return array_ops.expand_dims(weights, -1, name=scope)
supported_weights_shape = array_ops.concat([logits_shape[:-1], [1]], axis=0)
- condition = math_ops.reduce_any(
- [math_ops.reduce_all(math_ops.equal(logits_shape, weights_shape)),
- math_ops.reduce_all(math_ops.equal(
- supported_weights_shape, weights_shape))])
- assert_dimension = control_flow_ops.Assert(
- condition=condition,
- data=[err_msg, 'logits_shape: ', logits_shape,
- 'weights_shape: ', weights_shape])
+ if allow_per_logit_weights:
+ condition = math_ops.reduce_any(
+ [math_ops.reduce_all(math_ops.equal(logits_shape, weights_shape)),
+ math_ops.reduce_all(math_ops.equal(
+ supported_weights_shape, weights_shape))])
+ assert_dimension = control_flow_ops.Assert(
+ condition=condition,
+ data=[err_msg, 'logits_shape: ', logits_shape,
+ 'weights_shape: ', weights_shape])
+ else:
+ assert_dimension = check_ops.assert_equal(
+ supported_weights_shape, weights_shape, message=err_msg,
+ data=['logits_shape: ', logits_shape,
+ 'weights_shape: ', weights_shape])
with ops.control_dependencies([assert_dimension]):
return array_ops.identity(weights, name=scope)
-# TODO(roumposg): Delete once all heads support multi-dim input.
-def _check_logits(logits, expected_logits_dimension):
- """Check logits type and shape."""
- with ops.name_scope(None, 'logits', (logits,)) as scope:
- logits = math_ops.to_float(logits)
- logits_shape = array_ops.shape(logits)
- assert_rank = check_ops.assert_rank(
- logits, 2, data=[logits_shape],
- message='logits shape must be [batch_size, logits_dimension]')
- with ops.control_dependencies([assert_rank]):
- static_shape = logits.shape
- if static_shape is not None:
- dim1 = static_shape[1]
- if (dim1 is not None) and (dim1 != expected_logits_dimension):
- raise ValueError(
- 'logits shape must be [batch_size, logits_dimension], got %s.' %
- (static_shape,))
- assert_dimension = check_ops.assert_equal(
- expected_logits_dimension, logits_shape[1], data=[logits_shape],
- message='logits shape must be [batch_size, logits_dimension]')
- with ops.control_dependencies([assert_dimension]):
- return array_ops.identity(logits, name=scope)
-
-
def _check_logits_final_dim(logits, expected_logits_dimension):
"""Checks that logits shape is [D0, D1, ... DN, logits_dimension]."""
with ops.name_scope(None, 'logits', (logits,)) as scope:
@@ -575,10 +586,8 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
labels=label_ids, logits=logits, reduction=losses.Reduction.NONE)
# Restore the squeezed dim, so unweighted_loss matches the weights shape.
unweighted_loss = array_ops.expand_dims(unweighted_loss, axis=-1)
- weights = _weights(features, self._weight_column)
- if self._weight_column is not None:
- weights = _check_weights_match_logits_and_reshape(
- weights=weights, logits=logits)
+ weights = _get_weights_and_check_match_logits(
+ features=features, weight_column=self._weight_column, logits=logits)
weighted_sum_loss = losses.compute_weighted_loss(
unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
# _weights() can return 1.
@@ -680,7 +689,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
weight_column=None, thresholds=None, label_vocabulary=None, name=None):
- """Creates a `Head` for single label binary classification.
+ """Creates a `_Head` for single label binary classification.
This head uses `sigmoid_cross_entropy_with_logits` loss.
@@ -718,7 +727,7 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
Returns:
- An instance of `Head` for binary classification.
+ An instance of `_Head` for binary classification.
Raises:
ValueError: if `thresholds` contains a value outside of `(0, 1)`.
@@ -852,10 +861,8 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
labels = _assert_range(labels, 2)
unweighted_loss = nn.sigmoid_cross_entropy_with_logits(
labels=labels, logits=logits)
- weights = _weights(features, self._weight_column)
- if self._weight_column is not None:
- weights = _check_weights_match_logits_and_reshape(
- weights=weights, logits=logits)
+ weights = _get_weights_and_check_match_logits(
+ features=features, weight_column=self._weight_column, logits=logits)
weighted_sum_loss = losses.compute_weighted_loss(
unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
# _weights() can return 1.
@@ -918,12 +925,8 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
# Eval.
if mode == model_fn.ModeKeys.EVAL:
- weights = _weights(features, self._weight_column)
- # TODO(roumposg): Merge this logic inside _weights once all heads
- # support multi-dimensional inputs.
- if self._weight_column is not None:
- weights = _check_weights_match_logits_and_reshape(
- weights=weights, logits=logits)
+ weights = _get_weights_and_check_match_logits(
+ features=features, weight_column=self._weight_column, logits=logits)
return model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
predictions=predictions,
@@ -957,7 +960,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
def _regression_head_with_mean_squared_error_loss(weight_column=None,
label_dimension=1,
name=None):
- """Creates a `_Head` for regression using the mean squared loss.
+ """Creates a `_Head` for regression using the `mean_squared_error` loss.
The loss is the weighted sum over all input dimensions. Namely, if the input
labels have shape `[batch_size, label_dimension]`, the loss is the weighted
@@ -1023,10 +1026,9 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
labels = math_ops.to_float(labels)
unweighted_loss = losses.mean_squared_error(
labels=labels, predictions=logits, reduction=losses.Reduction.NONE)
- weights = _weights(features, self._weight_column)
- if self._weight_column is not None:
- weights = _check_weights_match_logits_and_reshape(
- weights=weights, logits=logits)
+ weights = _get_weights_and_check_match_logits(
+ features=features, weight_column=self._weight_column, logits=logits,
+ allow_per_logit_weights=True)
weighted_sum_loss = losses.compute_weighted_loss(
unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
# _weights() can return 1.
@@ -1111,18 +1113,19 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
train_op=train_op_fn(weighted_sum_loss))
-def _assert_range(labels, n_classes):
+def _assert_range(labels, n_classes, message=None):
with ops.name_scope(None, 'assert_range', (labels,)):
assert_less = check_ops.assert_less(
labels,
ops.convert_to_tensor(n_classes, dtype=labels.dtype),
- message='Label IDs must < n_classes')
+ message=message or 'Label IDs must < n_classes')
assert_greater = check_ops.assert_non_negative(
- labels, message='Label IDs must >= 0')
+ labels, message=message or 'Label IDs must >= 0')
with ops.control_dependencies((assert_less, assert_greater)):
return array_ops.identity(labels)
+# TODO(b/69000400): Delete this method.
def _weights(features, weight_column):
"""Fetches weights from features."""
with ops.name_scope(None, 'weights', values=features.values()):
diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py
index 0a4ea7d81c..4497cd26f2 100644
--- a/tensorflow/python/estimator/canned/head_test.py
+++ b/tensorflow/python/estimator/canned/head_test.py
@@ -987,12 +987,14 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
spec.loss.eval()
def test_multi_dim_train_weights_wrong_outer_dim(self):
- """Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 2, 2]."""
+ """Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 2, 3]."""
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
n_classes=3, weight_column='weights')
logits = np.array([[[10, 0, 0], [12, 0, 0]],
[[0, 10, 0], [0, 15, 0]]], dtype=np.float32)
labels = np.array([[[0], [1]], [[1], [2]]], dtype=np.int64)
+ weights = np.array([[[1., 1.1, 1.2], [1.5, 1.6, 1.7]],
+ [[2., 2.1, 2.2], [2.5, 2.6, 2.7]]])
weights_placeholder = array_ops.placeholder(dtype=dtypes.float32)
def _no_op_train_fn(loss):
del loss
@@ -1008,10 +1010,8 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
- r'\[logits_shape: \]\s\[2 2 3\]\s\[weights_shape: \]\s\[2 2 2\]'):
- spec.loss.eval({
- weights_placeholder: np.array([[[1., 1.1], [1.5, 1.6]],
- [[2., 2.1], [2.5, 2.6]]])})
+ r'\[logits_shape: \]\s\[2 2 3\]\s\[weights_shape: \]\s\[2 2 3\]'):
+ spec.loss.eval({weights_placeholder: weights})
def test_multi_dim_weighted_eval(self):
"""Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 2]."""
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index a730e107ba..2d036e2cfb 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -537,7 +537,7 @@ class Estimator(object):
temp_export_dir = get_temp_export_dir(export_dir)
# TODO(soergel): Consider whether MonitoredSession makes sense here
- with tf_session.Session() as session:
+ with tf_session.Session(config=self._session_config) as session:
saver_for_restore = estimator_spec.scaffold.saver or saver.Saver(
sharded=True)
diff --git a/tensorflow/python/estimator/estimator_lib.py b/tensorflow/python/estimator/estimator_lib.py
index 5b82fd75ff..bed2b67419 100644
--- a/tensorflow/python/estimator/estimator_lib.py
+++ b/tensorflow/python/estimator/estimator_lib.py
@@ -19,6 +19,8 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,line-too-long,wildcard-import
+from tensorflow.python.estimator.canned.baseline import BaselineClassifier
+from tensorflow.python.estimator.canned.baseline import BaselineRegressor
from tensorflow.python.estimator.canned.dnn import DNNClassifier
from tensorflow.python.estimator.canned.dnn import DNNRegressor
from tensorflow.python.estimator.canned.dnn_linear_combined import DNNLinearCombinedClassifier
@@ -46,6 +48,8 @@ from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
# Canned Estimators
+ 'BaselineClassifier',
+ 'BaselineRegressor',
'DNNClassifier',
'DNNRegressor',
'DNNLinearCombinedClassifier',
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 2b9b44523b..c1b773b8c4 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -50,6 +50,7 @@ from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import gfile
@@ -1910,6 +1911,71 @@ class EstimatorExportTest(test.TestCase):
est.train(dummy_input_fn, steps=1)
est.export_savedmodel(tempfile.mkdtemp(), serving_input_receiver_fn)
+ def test_export_savedmodel_respects_soft_placement(self):
+ def model_fn_with_a_gpu_op_but_no_kernel(features, labels, mode):
+ _, _ = features, labels
+ table = saver_test_utils.CheckpointedOp(name='v2')
+
+ update_global_step = state_ops.assign_add(training.get_global_step(), 1)
+ with ops.control_dependencies([update_global_step]):
+ train_op = table.insert('k1', 30.0)
+
+ # In this test, there are no GPUs available. The goal is to verify that
+ # export_savedmodel executes nevertheless.
+ with ops.device('/gpu:0'):
+ string_op = string_ops.as_string(update_global_step)
+
+ with ops.control_dependencies([string_op]):
+ prediction = table.lookup('k1', 0.0)
+
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ predictions=prediction,
+ loss=constant_op.constant(1.),
+ train_op=train_op,
+ export_outputs={
+ 'test': export_output.PredictOutput({
+ 'prediction': prediction
+ })
+ })
+
+ tmpdir = tempfile.mkdtemp()
+ est = estimator.Estimator(
+ model_fn=model_fn_with_a_gpu_op_but_no_kernel)
+ est.train(input_fn=dummy_input_fn, steps=1)
+ feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
+ 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)}
+ serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+ feature_spec)
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('export'))
+
+ export_dir = est.export_savedmodel(
+ export_dir_base, serving_input_receiver_fn)
+
+ # At this point, if export_savedmodel executed with
+ # allow_soft_placement=True, then the GPU-assigned operation was silently
+ # placed on the CPU. Otherwise, an exception would have been raised
+ # related to the fact that the requested GPU device isn't available.
+
+ # Expectations below assume that export_savedmodel has completed normally.
+ self.assertTrue(gfile.Exists(export_dir_base))
+ self.assertTrue(gfile.Exists(export_dir))
+ self.assertTrue(gfile.Exists(os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('saved_model.pb'))))
+ self.assertTrue(gfile.Exists(os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('variables'))))
+ self.assertTrue(gfile.Exists(os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('variables/variables.index'))))
+ self.assertTrue(gfile.Exists(os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('variables/variables.data-00000-of-00001'))))
+
+ gfile.DeleteRecursively(tmpdir)
+
class EstimatorHookOrderingTest(test.TestCase):
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index cef3f8d4c4..29cf223724 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -100,7 +100,7 @@ class Defun(object):
grad_func - (optional). A function implementing the gradient
of the function-to-register. This is must be a
`_DefinedFunction` object. The gradient
- function must satisify the criterion defined in
+ function must satisfy the criterion defined in
function.proto:GradientDef.
python_grad_func - (optional). A function implementing the
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 36b0737cfc..ba43e9199b 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -370,7 +370,7 @@ class FunctionTest(test.TestCase):
@function.Defun(dtypes.float32)
def Foo(x):
- y = logging_ops.Print(x, [x], "Hello")
+ y = logging_ops.Print(x, [], "Hello")
with ops.control_dependencies([y]):
z = control_flow_ops.no_op()
with ops.control_dependencies([z]):
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index ab4455534e..167a6c681e 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -617,15 +617,16 @@ class _EagerTensorBase(Tensor):
return dtypes._INTERN_TABLE[self._datatype_enum()] # pylint: disable=protected-access
def numpy(self):
- """Returns a numpy array with the same contents as the Tensor.
+ """Returns a numpy array or a scalar with the same contents as the Tensor.
TODO(ashankar,agarwal): Perhaps this should NOT reference the underlying
buffer but instead always explicitly copy? Note that currently it may or may
not copy based on whether the numpy data is properly aligned or not.
Returns:
- A numpy array that may share memory with the Tensor object. Any changes
- to one may be reflected in the other.
+ A numpy array or a scalar. Numpy array may share memory with the
+ Tensor object. Any changes to one may be reflected in the other. A scalar
+ value is returned when self has rank 0.
Raises:
ValueError: if the type of this Tensor is not representable in numpy.
@@ -1642,15 +1643,17 @@ class Operation(object):
def colocation_groups(self):
"""Returns the list of colocation groups of the op."""
default_colocation_group = [
- compat.as_bytes("loc:@%s" % self._node_def.name)
+ compat.as_bytes("loc:@%s" % self.name)
]
- if "_class" not in self._node_def.attr:
+ try:
+ class_attr = self.get_attr("_class")
+ except ValueError:
# This op has no explicit colocation group, so it is itself its
# own root of a colocation group.
return default_colocation_group
attr_groups = [
- class_name for class_name in self.get_attr("_class")
+ class_name for class_name in class_attr
if class_name.startswith(b"loc:@")
]
@@ -1895,7 +1898,7 @@ class Operation(object):
["^%s" % op.name for op in self._control_inputs])
def __str__(self):
- return str(self._node_def)
+ return str(self.node_def)
def __repr__(self):
return "<tf.Operation '%s' type=%s>" % (self.name, self.type)
@@ -2012,7 +2015,7 @@ class Operation(object):
@property
def node_def(self):
# pylint: disable=line-too-long
- """Returns a serialized `NodeDef` representation of this operation.
+ """Returns the `NodeDef` representation of this operation.
Returns:
A
@@ -2020,7 +2023,16 @@ class Operation(object):
protocol buffer.
"""
# pylint: enable=line-too-long
- return self._node_def
+ if self._c_op:
+ with c_api_util.tf_buffer() as buf:
+ with errors.raise_exception_on_not_ok_status() as status:
+ c_api.TF_OperationToNodeDef(self._c_op, buf, status)
+ data = c_api.TF_GetBuffer(buf)
+ node_def = node_def_pb2.NodeDef()
+ node_def.ParseFromString(compat.as_bytes(data))
+ return node_def
+ else:
+ return self._node_def
@property
def op_def(self):
@@ -2034,13 +2046,13 @@ class Operation(object):
"""
# pylint: enable=line-too-long
if self._c_op:
- with errors.raise_exception_on_not_ok_status() as status:
- with c_api_util.tf_buffer() as buf:
+ with c_api_util.tf_buffer() as buf:
+ with errors.raise_exception_on_not_ok_status() as status:
# pylint: disable=protected-access
c_api.TF_GraphGetOpDef(self._graph._c_graph,
compat.as_bytes(self.type), buf, status)
# pylint: enable=protected-access
- data = c_api.TF_GetBuffer(buf)
+ data = c_api.TF_GetBuffer(buf)
op_def = op_def_pb2.OpDef()
op_def.ParseFromString(compat.as_bytes(data))
return op_def
@@ -2065,16 +2077,19 @@ class Operation(object):
def _set_attr(self, attr_name, attr_value):
"""Private method used to set an attribute in the node_def."""
- if not _USE_C_API:
- assert "_set_attr not supported with _USE_C_API == False"
- return
- buf = c_api.TF_NewBufferFromString(
- compat.as_bytes(attr_value.SerializeToString()))
- try:
- with errors.raise_exception_on_not_ok_status() as status:
- c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf, status) # pylint: disable=protected-access
- finally:
- c_api.TF_DeleteBuffer(buf)
+ if _USE_C_API:
+ buf = c_api.TF_NewBufferFromString(
+ compat.as_bytes(attr_value.SerializeToString()))
+ try:
+ with errors.raise_exception_on_not_ok_status() as status:
+ # pylint: disable=protected-access
+ c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf,
+ status)
+ # pylint: enable=protected-access
+ finally:
+ c_api.TF_DeleteBuffer(buf)
+ else:
+ self._node_def.attr[attr_name].CopyFrom(attr_value)
def get_attr(self, name):
"""Returns the value of the attr of this op with the given `name`.
@@ -2088,25 +2103,24 @@ class Operation(object):
Raises:
ValueError: If this op does not have an attr with the given `name`.
"""
- if _USE_C_API:
+ fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"]
+ if self._c_op:
try:
- # TODO(b/65162920): remove this try/except block when all attrs are
- # implemented to use the _set_attr method instead of node_def.attr.
- with errors.raise_exception_on_not_ok_status() as status:
- metadata = c_api.TF_OperationGetAttrMetadata(self._c_op, name, status)
- with errors.raise_exception_on_not_ok_status() as status:
- if metadata.type == c_api.TF_ATTR_INT and metadata.is_list == 0:
- return c_api.TF_OperationGetAttrInt(self._c_op, name, status)
- except errors.InvalidArgumentError:
- # Colocation ops are failing to find attrs begininning with "_*". They
- # should fall through to the not-CAPI logic until the attribute is set
- # via the C-API always.
- pass
+ with c_api_util.tf_buffer() as buf:
+ with errors.raise_exception_on_not_ok_status() as status:
+ c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf, status)
+ data = c_api.TF_GetBuffer(buf)
+ except errors.InvalidArgumentError as e:
+ # Convert to ValueError for backwards compatibility.
+ raise ValueError(str(e))
+ x = attr_value_pb2.AttrValue()
+ x.ParseFromString(data)
+ else:
+ if name not in self._node_def.attr:
+ raise ValueError(
+ "No attr named '" + name + "' in " + str(self._node_def))
+ x = self._node_def.attr[name]
- fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"]
- if name not in self._node_def.attr:
- raise ValueError("No attr named '" + name + "' in " + str(self._node_def))
- x = self._node_def.attr[name]
# Treat an empty oneof value as an empty list.
if not x.WhichOneof("value"):
return []
@@ -2749,10 +2763,10 @@ class Graph(object):
"""
# pylint: enable=line-too-long
if self._c_graph:
- with errors.raise_exception_on_not_ok_status() as status:
- with c_api_util.tf_buffer() as buf:
+ with c_api_util.tf_buffer() as buf:
+ with errors.raise_exception_on_not_ok_status() as status:
c_api.TF_GraphVersions(self._c_graph, buf, status)
- data = c_api.TF_GetBuffer(buf)
+ data = c_api.TF_GetBuffer(buf)
version_def = versions_pb2.VersionDef()
version_def.ParseFromString(compat.as_bytes(data))
return version_def
@@ -3106,9 +3120,10 @@ class Graph(object):
ret._set_device(colocation_op.device) # pylint: disable=protected-access
all_colocation_groups = sorted(set(all_colocation_groups))
- ret.node_def.attr["_class"].CopyFrom(
- attr_value_pb2.AttrValue(list=attr_value_pb2.AttrValue.ListValue(
- s=all_colocation_groups)))
+ # pylint: disable=protected-access
+ ret._set_attr("_class", attr_value_pb2.AttrValue(
+ list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))
+ # pylint: enable=protected-access
# Sets "container" attribute if
# (1) self._container is not None
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 3087d6060b..4e931e00c5 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -31,9 +31,11 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions
@@ -357,54 +359,55 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertEqual("<tf.Operation 'op1' type=None>", repr(op))
def testGetAttr(self):
- # TODO(b/65162920): implement all tests for get_attr with C API
+ op = test_ops.default_attrs()
+ self.assertEqual(op.get_attr("string_val"), b"abc")
+ self.assertEqual(op.get_attr("string_list_val"), [b"abc", b""])
+ self.assertEqual(op.get_attr("int_val"), 123)
+ self.assertEqual(op.get_attr("int_list_val"), [1, 2, 3])
+ self.assertEqual(op.get_attr("float_val"), 10.0)
+ self.assertEqual(op.get_attr("float_list_val"), [10.0])
+ self.assertEqual(op.get_attr("bool_val"), True)
+ self.assertEqual(op.get_attr("bool_list_val"), [True, False])
+ self.assertEqual(op.get_attr("shape_val"),
+ tensor_shape.as_shape([2, 1]).as_proto())
+ self.assertEqual(op.get_attr("shape_list_val"),
+ [tensor_shape.as_shape([]).as_proto(),
+ tensor_shape.as_shape([1]).as_proto()])
+ self.assertEqual(op.get_attr("tensor_val"),
+ tensor_util.make_tensor_proto(1, dtypes.int32))
+ self.assertEqual(op.get_attr("tensor_list_val"),
+ [tensor_util.make_tensor_proto(1, dtypes.int32)])
+
+ type_val = op.get_attr("type_val")
+ # First check that type_val is a DType, because the assertEquals will work
+ # no matter what since DType overrides __eq__
+ self.assertIsInstance(type_val, dtypes.DType)
+ self.assertEqual(type_val, dtypes.int32)
+
+ type_list_val = op.get_attr("type_list_val")
+ self.assertTrue(all(isinstance(x, dtypes.DType) for x in type_list_val))
+ self.assertEqual(type_list_val, [dtypes.int32, dtypes.float32])
+
+ @function.Defun(dtypes.float32, func_name="MyFunc")
+ def func(x):
+ return x
+
+ op = test_ops.func_attr(func)
+ self.assertEqual(op.get_attr("f"),
+ attr_value_pb2.NameAttrList(name="MyFunc"))
+
+ # Try fetching missing attr
if ops._USE_C_API:
- op = test_ops.int_attr().op
- self.assertEqual(op.get_attr("foo"), 1)
-
- op_str = test_ops.string_list_attr(a=["z"], b="y")
- self.assertEqual(op_str.get_attr("a"), [b"z"])
- self.assertEqual(op_str.get_attr("b"), b"y")
-
+ error_msg = "Operation 'FuncAttr' has no attr named 'FakeAttr'."
else:
- list_value = attr_value_pb2.AttrValue.ListValue()
-
- list_value.type.append(types_pb2.DT_STRING)
- list_value.type.append(types_pb2.DT_DOUBLE)
- op = ops.Operation(
- ops._NodeDef(
- "None",
- "op1",
- attrs={
- "value":
- attr_value_pb2.AttrValue(i=32),
- "dtype":
- attr_value_pb2.AttrValue(type=types_pb2.DT_INT32),
- "list":
- attr_value_pb2.AttrValue(list=list_value),
- "func":
- attr_value_pb2.AttrValue(
- func=attr_value_pb2.NameAttrList())
- }), ops.Graph(), [], [dtypes.int32])
- self.assertEqual(32, op.get_attr("value"))
- self.assertEqual("", op.get_attr("func").name)
-
- d = op.get_attr("dtype")
- # First check that d is a DType, because the assertEquals will
- # work no matter what since DType overrides __eq__
- self.assertIsInstance(d, dtypes.DType)
- self.assertEqual(dtypes.int32, d)
-
- l = op.get_attr("list")
- for x in l:
- self.assertIsInstance(x, dtypes.DType)
- self.assertEqual([dtypes.string, dtypes.double], l)
+ error_msg = "No attr named 'FakeAttr' in name: \"FuncAttr\""
+
+ with self.assertRaisesRegexp(ValueError, error_msg):
+ op.get_attr("FakeAttr")
# TODO(b/65162920): remove this test when users who are directly mutating the
# node_def have been updated to proper usage.
def testSetAttr(self):
- if not ops._USE_C_API:
- return
op = test_ops.int_attr().op
op._set_attr("foo", attr_value_pb2.AttrValue(i=2))
# TODO(skyewm): add node_def check
diff --git a/tensorflow/python/framework/test_ops.cc b/tensorflow/python/framework/test_ops.cc
index a8b7fc543f..35e0167b26 100644
--- a/tensorflow/python/framework/test_ops.cc
+++ b/tensorflow/python/framework/test_ops.cc
@@ -341,4 +341,27 @@ REGISTER_OP("StringListAttr")
.Attr("b: string")
.SetShapeFn(shape_inference::UnknownShape);
+REGISTER_OP("DefaultAttrs")
+ .Attr("string_val: string = 'abc'")
+ .Attr("string_list_val: list(string) = ['abc', '']")
+ .Attr("int_val: int = 123")
+ .Attr("int_list_val: list(int) = [1, 2, 3]")
+ .Attr("float_val: float = 10.0")
+ .Attr("float_list_val: list(float) = [10.0]")
+ .Attr("bool_val: bool = true")
+ .Attr("bool_list_val: list(bool) = [true, false]")
+ .Attr("type_val: type = DT_INT32")
+ .Attr("type_list_val: list(type) = [DT_INT32, DT_FLOAT]")
+ .Attr("shape_val: shape = { dim { size: 2 } dim { size: 1 } }")
+ .Attr("shape_list_val: list(shape) = [{}, { dim { size: 1} }]")
+ .Attr("tensor_val: tensor = { dtype: DT_INT32 tensor_shape: {} int_val: 1}")
+ .Attr(
+ "tensor_list_val: list(tensor) = "
+ "[{ dtype: DT_INT32 tensor_shape: {} int_val: 1}]")
+ .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("FuncAttr")
+ .Attr("f: func")
+ .SetShapeFn(shape_inference::UnknownShape);
+
} // end namespace tensorflow
diff --git a/tensorflow/python/grappler/model_analyzer.cc b/tensorflow/python/grappler/model_analyzer.cc
index 4ec7620bce..7d365c3be9 100644
--- a/tensorflow/python/grappler/model_analyzer.cc
+++ b/tensorflow/python/grappler/model_analyzer.cc
@@ -59,10 +59,15 @@ void ModelAnalyzer::PrintNodeInfo(const NodeDef* node,
if (i > 0) {
os << ", ";
}
- if (prop.shape().dim(i).size() < 0) {
+ if (prop.shape().dim(i).size() >= 0) {
+ // Print the actual dimension.
+ os << prop.shape().dim(i).size();
+ } else if (prop.shape().dim(i).size() == -1) {
+ // We don't know anything about the dimension.
os << "?";
} else {
- os << prop.shape().dim(i).size();
+ // Symbolic dimension.
+ os << "x" << -prop.shape().dim(i).size();
}
}
os << "]";
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 4db48b45ed..6a762ee5d2 100644
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -499,6 +499,18 @@ py_test(
)
py_test(
+ name = "recurrent_test",
+ size = "small",
+ srcs = ["_impl/keras/layers/recurrent_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":keras",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "serialization_test",
size = "small",
srcs = ["_impl/keras/layers/serialization_test.py"],
diff --git a/tensorflow/python/keras/_impl/keras/engine/topology.py b/tensorflow/python/keras/_impl/keras/engine/topology.py
index f9be782f85..2bcbabf19c 100644
--- a/tensorflow/python/keras/_impl/keras/engine/topology.py
+++ b/tensorflow/python/keras/_impl/keras/engine/topology.py
@@ -29,6 +29,9 @@ from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.python.eager import context
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras import constraints
+from tensorflow.python.keras._impl.keras import initializers
+from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.utils import conv_utils
from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.keras._impl.keras.utils.layer_utils import print_summary as print_layer_summary
@@ -209,9 +212,9 @@ class Layer(tf_base_layers.Layer):
dtype = K.floatx()
weight = self.add_variable(name, shape,
dtype=dtype,
- initializer=initializer,
- regularizer=regularizer,
- constraint=constraint,
+ initializer=initializers.get(initializer),
+ regularizer=regularizers.get(regularizer),
+ constraint=constraints.get(constraint),
trainable=trainable)
return weight
diff --git a/tensorflow/python/keras/_impl/keras/integration_test.py b/tensorflow/python/keras/_impl/keras/integration_test.py
index 7110036848..871a8c7329 100644
--- a/tensorflow/python/keras/_impl/keras/integration_test.py
+++ b/tensorflow/python/keras/_impl/keras/integration_test.py
@@ -93,7 +93,7 @@ class KerasIntegrationTest(test.TestCase):
y_test = keras.utils.to_categorical(y_test)
model = keras.models.Sequential()
- model.add(keras.layers.LSTM(3, return_sequences=True,
+ model.add(keras.layers.LSTM(5, return_sequences=True,
input_shape=x_train.shape[1:]))
model.add(keras.layers.GRU(y_train.shape[-1], activation='softmax'))
model.compile(loss='categorical_crossentropy',
diff --git a/tensorflow/python/keras/_impl/keras/layers/gru_test.py b/tensorflow/python/keras/_impl/keras/layers/gru_test.py
index 03f0736161..c57fbac41c 100644
--- a/tensorflow/python/keras/_impl/keras/layers/gru_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/gru_test.py
@@ -156,8 +156,10 @@ class GRULayerTest(test.TestCase):
activity_regularizer='l1')
layer.build((None, None, 2))
self.assertEqual(len(layer.losses), 3)
- layer(keras.backend.variable(np.ones((2, 3, 2))))
- self.assertEqual(len(layer.losses), 4)
+
+ x = keras.backend.variable(np.ones((2, 3, 2)))
+ layer(x)
+ self.assertEqual(len(layer.get_losses_for(x)), 1)
def test_constraints_GRU(self):
embedding_dim = 4
@@ -175,9 +177,9 @@ class GRULayerTest(test.TestCase):
recurrent_constraint=r_constraint,
bias_constraint=b_constraint)
layer.build((None, None, embedding_dim))
- self.assertEqual(layer.kernel.constraint, k_constraint)
- self.assertEqual(layer.recurrent_kernel.constraint, r_constraint)
- self.assertEqual(layer.bias.constraint, b_constraint)
+ self.assertEqual(layer.cell.kernel.constraint, k_constraint)
+ self.assertEqual(layer.cell.recurrent_kernel.constraint, r_constraint)
+ self.assertEqual(layer.cell.bias.constraint, b_constraint)
def test_with_masking_layer_GRU(self):
layer_class = keras.layers.GRU
diff --git a/tensorflow/python/keras/_impl/keras/layers/lstm_test.py b/tensorflow/python/keras/_impl/keras/layers/lstm_test.py
index f43d90fec8..8d359bf17c 100644
--- a/tensorflow/python/keras/_impl/keras/layers/lstm_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/lstm_test.py
@@ -156,8 +156,9 @@ class LSTMLayerTest(test.TestCase):
activity_regularizer='l1')
layer.build((None, None, 2))
self.assertEqual(len(layer.losses), 3)
- layer(keras.backend.variable(np.ones((2, 3, 2))))
- self.assertEqual(len(layer.losses), 4)
+ x = keras.backend.variable(np.ones((2, 3, 2)))
+ layer(x)
+ self.assertEqual(len(layer.get_losses_for(x)), 1)
def test_constraints_LSTM(self):
embedding_dim = 4
@@ -175,9 +176,9 @@ class LSTMLayerTest(test.TestCase):
recurrent_constraint=r_constraint,
bias_constraint=b_constraint)
layer.build((None, None, embedding_dim))
- self.assertEqual(layer.kernel.constraint, k_constraint)
- self.assertEqual(layer.recurrent_kernel.constraint, r_constraint)
- self.assertEqual(layer.bias.constraint, b_constraint)
+ self.assertEqual(layer.cell.kernel.constraint, k_constraint)
+ self.assertEqual(layer.cell.recurrent_kernel.constraint, r_constraint)
+ self.assertEqual(layer.cell.bias.constraint, b_constraint)
def test_with_masking_layer_LSTM(self):
layer_class = keras.layers.LSTM
diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
index 139523403c..2bc74d5f80 100644
--- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
@@ -1,4 +1,4 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,99 +29,209 @@ from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg
+from tensorflow.python.platform import tf_logging as logging
-# pylint: disable=access-member-before-definition
+class StackedRNNCells(Layer):
+ """Wrapper allowing a stack of RNN cells to behave as a single cell.
-
-def _time_distributed_dense(x,
- w,
- b=None,
- dropout=None,
- input_dim=None,
- output_dim=None,
- timesteps=None,
- training=None):
- """Apply `y . w + b` for every temporal slice y of x.
+ Used to implement efficient stacked RNNs.
Arguments:
- x: input tensor.
- w: weight matrix.
- b: optional bias vector.
- dropout: whether to apply dropout (same dropout mask
- for every temporal slice of the input).
- input_dim: integer; optional dimensionality of the input.
- output_dim: integer; optional dimensionality of the output.
- timesteps: integer; optional number of timesteps.
- training: training phase tensor or boolean.
-
- Returns:
- Output tensor.
- """
- if not input_dim:
- input_dim = K.shape(x)[2]
- if not timesteps:
- timesteps = K.shape(x)[1]
- if not output_dim:
- output_dim = K.shape(w)[1]
-
- if dropout is not None and 0. < dropout < 1.:
- # apply the same dropout pattern at every timestep
- ones = K.ones_like(K.reshape(x[:, 0, :], (-1, input_dim)))
- dropout_matrix = K.dropout(ones, dropout)
- expanded_dropout_matrix = K.repeat(dropout_matrix, timesteps)
- x = K.in_train_phase(x * expanded_dropout_matrix, x, training=training)
-
- # collapse time dimension and batch dimension together
- x = K.reshape(x, (-1, input_dim))
- x = K.dot(x, w)
- if b is not None:
- x = K.bias_add(x, b)
- # reshape to 3D tensor
- if K.backend() == 'tensorflow':
- x = K.reshape(x, K.stack([-1, timesteps, output_dim]))
- x.set_shape([None, None, output_dim])
- else:
- x = K.reshape(x, (-1, timesteps, output_dim))
- return x
+ cells: List of RNN cell instances.
+ Examples:
-class Recurrent(Layer):
- """Abstract base class for recurrent layers.
+ ```python
+ cells = [
+ keras.layers.LSTMCell(output_dim),
+ keras.layers.LSTMCell(output_dim),
+ keras.layers.LSTMCell(output_dim),
+ ]
- Do not use in a model -- it's not a valid layer!
- Use its children classes `LSTM`, `GRU` and `SimpleRNN` instead.
+ inputs = keras.Input((timesteps, input_dim))
+ x = keras.layers.RNN(cells)(inputs)
+ ```
+ """
- All recurrent layers (`LSTM`, `GRU`, `SimpleRNN`) also
- follow the specifications of this class and accept
- the keyword arguments listed below.
+ def __init__(self, cells, **kwargs):
+ for cell in cells:
+ if not hasattr(cell, 'call'):
+ raise ValueError('All cells must have a `call` method. '
+ 'received cells:', cells)
+ if not hasattr(cell, 'state_size'):
+ raise ValueError('All cells must have a '
+ '`state_size` attribute. '
+ 'received cells:', cells)
+ self.cells = cells
+ super(StackedRNNCells, self).__init__(**kwargs)
+
+ @property
+ def state_size(self):
+ # States are a flat list
+ # in reverse order of the cell stack.
+ # This allows to preserve the requirement
+ # `stack.state_size[0] == output_dim`.
+ # e.g. states of a 2-layer LSTM would be
+ # `[h2, c2, h1, c1]`
+ # (assuming one LSTM has states [h, c])
+ state_size = []
+ for cell in self.cells[::-1]:
+ if hasattr(cell.state_size, '__len__'):
+ state_size += list(cell.state_size)
+ else:
+ state_size.append(cell.state_size)
+ return tuple(state_size)
+
+ def call(self, inputs, states, **kwargs):
+ # Recover per-cell states.
+ nested_states = []
+ for cell in self.cells[::-1]:
+ if hasattr(cell.state_size, '__len__'):
+ nested_states.append(states[:len(cell.state_size)])
+ states = states[len(cell.state_size):]
+ else:
+ nested_states.append([states[0]])
+ states = states[1:]
+ nested_states = nested_states[::-1]
+
+ # Call the cells in order and store the returned states.
+ new_nested_states = []
+ for cell, states in zip(self.cells, nested_states):
+ inputs, states = cell.call(inputs, states, **kwargs)
+ new_nested_states.append(states)
+
+ # Format the new states as a flat list
+ # in reverse cell order.
+ states = []
+ for cell_states in new_nested_states[::-1]:
+ states += cell_states
+ return inputs, states
- Example:
+ def build(self, input_shape):
+ for cell in self.cells:
+ if isinstance(cell, Layer):
+ cell.build(input_shape)
+ if hasattr(cell.state_size, '__len__'):
+ output_dim = cell.state_size[0]
+ else:
+ output_dim = cell.state_size
+ input_shape = (input_shape[0], input_shape[1], output_dim)
+ self.built = True
- ```python
- # as the first layer in a Sequential model
- model = Sequential()
- model.add(LSTM(32, input_shape=(10, 64)))
- # now model.output_shape == (None, 32)
- # note: `None` is the batch dimension.
-
- # for subsequent layers, no need to specify the input size:
- model.add(LSTM(16))
-
- # to stack recurrent layers, you must use return_sequences=True
- # on any recurrent layer that feeds into another recurrent layer.
- # note that you only need to specify the input size on the first layer.
- model = Sequential()
- model.add(LSTM(64, input_dim=64, input_length=10, return_sequences=True))
- model.add(LSTM(32, return_sequences=True))
- model.add(LSTM(10))
- ```
+ def get_config(self):
+ cells = []
+ for cell in self.cells:
+ cells.append({
+ 'class_name': cell.__class__.__name__,
+ 'config': cell.get_config()
+ })
+ config = {'cells': cells}
+ base_config = super(StackedRNNCells, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ from tensorflow.python.keras._impl.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
+ cells = []
+ for cell_config in config.pop('cells'):
+ cells.append(
+ deserialize_layer(cell_config, custom_objects=custom_objects))
+ return cls(cells, **config)
+
+ @property
+ def trainable_weights(self):
+ if not self.trainable:
+ return []
+ weights = []
+ for cell in self.cells:
+ if isinstance(cell, Layer):
+ weights += cell.trainable_weights
+ return weights
+
+ @property
+ def non_trainable_weights(self):
+ weights = []
+ for cell in self.cells:
+ if isinstance(cell, Layer):
+ weights += cell.non_trainable_weights
+ if not self.trainable:
+ trainable_weights = []
+ for cell in self.cells:
+ if isinstance(cell, Layer):
+ trainable_weights += cell.trainable_weights
+ return trainable_weights + weights
+ return weights
+
+ def get_weights(self):
+ """Retrieves the weights of the model.
+
+ Returns:
+ A flat list of Numpy arrays.
+ """
+ weights = []
+ for cell in self.cells:
+ if isinstance(cell, Layer):
+ weights += cell.weights
+ return K.batch_get_value(weights)
+
+ def set_weights(self, weights):
+ """Sets the weights of the model.
+
+ Arguments:
+ weights: A list of Numpy arrays with shapes and types matching
+ the output of `model.get_weights()`.
+ """
+ tuples = []
+ for cell in self.cells:
+ if isinstance(cell, Layer):
+ num_param = len(cell.weights)
+ weights = weights[:num_param]
+ for sw, w in zip(cell.weights, weights):
+ tuples.append((sw, w))
+ weights = weights[num_param:]
+ K.batch_set_value(tuples)
+
+ @property
+ def losses(self):
+ losses = []
+ for cell in self.cells:
+ if isinstance(cell, Layer):
+ cell_losses = cell.losses
+ losses += cell_losses
+ return losses
+
+ def get_losses_for(self, inputs=None):
+ losses = []
+ for cell in self.cells:
+ if isinstance(cell, Layer):
+ cell_losses = cell.get_losses_for(inputs)
+ losses += cell_losses
+ return losses
+
+
+class RNN(Layer):
+ """Base class for recurrent layers.
Arguments:
- weights: list of Numpy arrays to set as initial weights.
- The list should have 3 elements, of shapes:
- `[(input_dim, output_dim), (output_dim, output_dim), (output_dim,)]`.
- return_sequences: Boolean. Whether to return the last output
+ cell: A RNN cell instance. A RNN cell is a class that has:
+ - a `call(input_at_t, states_at_t)` method, returning
+ `(output_at_t, states_at_t_plus_1)`. The call method of the
+ cell can also take the optional argument `constants`, see
+ section "Note on passing external constants" below.
+ - a `state_size` attribute. This can be a single integer
+ (single state) in which case it is
+ the size of the recurrent state
+ (which should be the same as the size of the cell output).
+ This can also be a list/tuple of integers
+ (one size per state). In this case, the first entry
+ (`state_size[0]`) should be the same as
+ the size of the cell output.
+ It is also possible for `cell` to be a list of RNN cell instances,
+ in which cases the cells get stacked on after the other in the RNN,
+ implementing an efficient stacked RNN.
+ return_sequences: Boolean. Whether to return the last output.
in the output sequence, or the full sequence.
return_state: Boolean. Whether to return the last state
in addition to the output.
@@ -137,21 +247,9 @@ class Recurrent(Layer):
Unrolling can speed-up a RNN,
although it tends to be more memory-intensive.
Unrolling is only suitable for short sequences.
- implementation: one of {0, 1, or 2}.
- If set to 0, the RNN will use
- an implementation that uses fewer, larger matrix products,
- thus running faster on CPU but consuming more memory.
- If set to 1, the RNN will use more matrix products,
- but smaller ones, thus running slower
- (may actually be faster on GPU) while consuming less memory.
- If set to 2 (LSTM/GRU only),
- the RNN will combine the input gate,
- the forget gate and the output gate into a single matrix,
- enabling more time-efficient parallelization on the GPU.
- Note: RNN dropout must be shared for all gates,
- resulting in a slightly reduced regularization.
input_dim: dimensionality of the input (integer).
- This argument (or alternatively, the keyword argument `input_shape`)
+ This argument (or alternatively,
+ the keyword argument `input_shape`)
is required when using this layer as the first layer in a model.
input_length: Length of input sequences, to be specified
when it is constant.
@@ -163,7 +261,7 @@ class Recurrent(Layer):
at the level of the first layer
(e.g. via the `input_shape` argument)
- Input shape:s
+ Input shape:
3D tensor with shape `(batch_size, timesteps, input_dim)`,
(Optional) 2D tensors with shape `(batch_size, output_dim)`.
@@ -178,7 +276,7 @@ class Recurrent(Layer):
# Masking
This layer supports masking for input data with a variable number
of timesteps. To introduce masks to your data,
- use an `Embedding` layer with the `mask_zero` parameter
+ use an [Embedding](embeddings.md) layer with the `mask_zero` parameter
set to `True`.
# Note on using statefulness in RNNs
@@ -212,42 +310,128 @@ class Recurrent(Layer):
calling `reset_states` with the keyword argument `states`. The value of
`states` should be a numpy array or list of numpy arrays representing
the initial state of the RNN layer.
+
+ # Note on passing external constants to RNNs
+ You can pass "external" constants to the cell using the `constants`
+ keyword argument of `RNN.__call__` (as well as `RNN.call`) method. This
+ requires that the `cell.call` method accepts the same keyword argument
+ `constants`. Such constants can be used to condition the cell
+ transformation on additional static inputs (not changing over time),
+ a.k.a. an attention mechanism.
+
+ Examples:
+
+ ```python
+ # First, let's define a RNN Cell, as a layer subclass.
+
+ class MinimalRNNCell(keras.layers.Layer):
+
+ def __init__(self, units, **kwargs):
+ self.units = units
+ self.state_size = units
+ super(MinimalRNNCell, self).__init__(**kwargs)
+
+ def build(self, input_shape):
+ self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
+ initializer='uniform',
+ name='kernel')
+ self.recurrent_kernel = self.add_weight(
+ shape=(self.units, self.units),
+ initializer='uniform',
+ name='recurrent_kernel')
+ self.built = True
+
+ def call(self, inputs, states):
+ prev_output = states[0]
+ h = K.dot(inputs, self.kernel)
+ output = h + K.dot(prev_output, self.recurrent_kernel)
+ return output, [output]
+
+ # Let's use this cell in a RNN layer:
+
+ cell = MinimalRNNCell(32)
+ x = keras.Input((None, 5))
+ layer = RNN(cell)
+ y = layer(x)
+
+ # Here's how to use the cell to build a stacked RNN:
+
+ cells = [MinimalRNNCell(32), MinimalRNNCell(64)]
+ x = keras.Input((None, 5))
+ layer = RNN(cells)
+ y = layer(x)
+ ```
"""
def __init__(self,
+ cell,
return_sequences=False,
return_state=False,
go_backwards=False,
stateful=False,
unroll=False,
- implementation=0,
+ activity_regularizer=None,
**kwargs):
- super(Recurrent, self).__init__(**kwargs)
+ if isinstance(cell, (list, tuple)):
+ cell = StackedRNNCells(cell)
+ if not hasattr(cell, 'call'):
+ raise ValueError('`cell` should have a `call` method. '
+ 'The RNN was passed:', cell)
+ if not hasattr(cell, 'state_size'):
+ raise ValueError('The RNN cell should have '
+ 'an attribute `state_size` '
+ '(tuple of integers, '
+ 'one integer per RNN state).')
+ super(RNN, self).__init__(
+ activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
+ self.cell = cell
self.return_sequences = return_sequences
self.return_state = return_state
self.go_backwards = go_backwards
self.stateful = stateful
self.unroll = unroll
- self.implementation = implementation
+
self.supports_masking = True
self.input_spec = [InputSpec(ndim=3)]
self.state_spec = None
- self.dropout = 0
- self.recurrent_dropout = 0
+ self._states = None
+ self.constants_spec = None
+ self._num_constants = None
+
+ @property
+ def states(self):
+ if self._states is None:
+ if isinstance(self.cell.state_size, int):
+ num_states = 1
+ else:
+ num_states = len(self.cell.state_size)
+ return [None for _ in range(num_states)]
+ return self._states
+
+ @states.setter
+ def states(self, states):
+ self._states = states
def _compute_output_shape(self, input_shape):
if isinstance(input_shape, list):
input_shape = input_shape[0]
input_shape = tensor_shape.TensorShape(input_shape).as_list()
+
+ if hasattr(self.cell.state_size, '__len__'):
+ output_dim = self.cell.state_size[0]
+ else:
+ output_dim = self.cell.state_size
+
if self.return_sequences:
- output_shape = (input_shape[0], input_shape[1], self.units)
+ output_shape = (input_shape[0], input_shape[1], output_dim)
else:
- output_shape = (input_shape[0], self.units)
+ output_shape = (input_shape[0], output_dim)
if self.return_state:
- state_shape = [tensor_shape.TensorShape(
- (input_shape[0], self.units)) for _ in self.states]
- return [tensor_shape.TensorShape(output_shape)] + state_shape
+ state_shape = [(input_shape[0], output_dim) for _ in self.states]
+ output_shape = [output_shape] + state_shape
+ else:
+ output_shape = output_shape
return tensor_shape.TensorShape(output_shape)
def compute_mask(self, inputs, mask):
@@ -257,82 +441,123 @@ class Recurrent(Layer):
if self.return_state:
state_mask = [None for _ in self.states]
return [output_mask] + state_mask
- return output_mask
+ else:
+ return output_mask
- def step(self, inputs, states):
- raise NotImplementedError
+ def build(self, input_shape):
+ # Note input_shape will be list of shapes of initial states and
+ # constants if these are passed in __call__.
+ if self._num_constants is not None:
+ constants_shape = input_shape[-self._num_constants:] # pylint: disable=invalid-unary-operand-type
+ else:
+ constants_shape = None
- def get_constants(self, inputs, training=None):
- return []
+ if isinstance(input_shape, list):
+ input_shape = input_shape[0]
+ input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list())
+
+ batch_size = input_shape[0] if self.stateful else None
+ input_dim = input_shape[-1]
+ self.input_spec[0] = InputSpec(shape=(batch_size, None, input_dim))
+
+ # allow cell (if layer) to build before we set or validate state_spec
+ if isinstance(self.cell, Layer):
+ step_input_shape = (input_shape[0],) + input_shape[2:]
+ if constants_shape is not None:
+ self.cell.build([step_input_shape] + constants_shape)
+ else:
+ self.cell.build(step_input_shape)
+
+ # set or validate state_spec
+ if hasattr(self.cell.state_size, '__len__'):
+ state_size = list(self.cell.state_size)
+ else:
+ state_size = [self.cell.state_size]
+
+ if self.state_spec is not None:
+ # initial_state was passed in call, check compatibility
+ if [spec.shape[-1] for spec in self.state_spec] != state_size:
+ raise ValueError(
+ 'An initial_state was passed that is not compatible with '
+ '`cell.state_size`. Received `state_spec`={}; '
+ 'However `cell.state_size` is '
+ '{}'.format(self.state_spec, self.cell.state_size))
+ else:
+ self.state_spec = [InputSpec(shape=(None, dim)) for dim in state_size]
+ if self.stateful:
+ self.reset_states()
def get_initial_state(self, inputs):
# build an all-zero tensor of shape (samples, output_dim)
initial_state = K.zeros_like(inputs) # (samples, timesteps, input_dim)
initial_state = K.sum(initial_state, axis=(1, 2)) # (samples,)
initial_state = K.expand_dims(initial_state) # (samples, 1)
- initial_state = K.tile(initial_state, [1,
- self.units]) # (samples, output_dim)
- initial_state = [initial_state for _ in range(len(self.states))]
- return initial_state
-
- def preprocess_input(self, inputs, training=None):
- return inputs
+ if hasattr(self.cell.state_size, '__len__'):
+ return [K.tile(initial_state, [1, dim]) for dim in self.cell.state_size]
+ else:
+ return [K.tile(initial_state, [1, self.cell.state_size])]
- def __call__(self, inputs, initial_state=None, **kwargs):
- if (isinstance(inputs, (list, tuple)) and
- len(inputs) > 1
- and initial_state is None):
- initial_state = inputs[1:]
- inputs = inputs[0]
+ def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
+ inputs, initial_state, constants = self._standardize_args(
+ inputs, initial_state, constants)
- # If `initial_state` is specified,
- # and if it a Keras tensor,
- # then add it to the inputs and temporarily
- # modify the input spec to include the state.
- if initial_state is None:
- return super(Recurrent, self).__call__(inputs, **kwargs)
+ if initial_state is None and constants is None:
+ return super(RNN, self).__call__(inputs, **kwargs)
- if not isinstance(initial_state, (list, tuple)):
- initial_state = [initial_state]
+ # If any of `initial_state` or `constants` are specified and are Keras
+ # tensors, then add them to the inputs and temporarily modify the
+ # input_spec to include them.
- is_keras_tensor = hasattr(initial_state[0], '_keras_history')
- for tensor in initial_state:
+ additional_inputs = []
+ additional_specs = []
+ if initial_state is not None:
+ kwargs['initial_state'] = initial_state
+ additional_inputs += initial_state
+ self.state_spec = [
+ InputSpec(shape=K.int_shape(state)) for state in initial_state
+ ]
+ additional_specs += self.state_spec
+ if constants is not None:
+ kwargs['constants'] = constants
+ additional_inputs += constants
+ self.constants_spec = [
+ InputSpec(shape=K.int_shape(constant)) for constant in constants
+ ]
+ self._num_constants = len(constants)
+ additional_specs += self.constants_spec
+ # at this point additional_inputs cannot be empty
+ is_keras_tensor = hasattr(additional_inputs[0], '_keras_history')
+ for tensor in additional_inputs:
if hasattr(tensor, '_keras_history') != is_keras_tensor:
- raise ValueError('The initial state of an RNN layer cannot be'
- ' specified with a mix of Keras tensors and'
- ' non-Keras tensors')
+ raise ValueError('The initial state or constants of an RNN'
+ ' layer cannot be specified with a mix of'
+ ' Keras tensors and non-Keras tensors')
if is_keras_tensor:
- # Compute the full input spec, including state
- input_spec = self.input_spec
- state_spec = self.state_spec
- if not isinstance(input_spec, list):
- input_spec = [input_spec]
- if not isinstance(state_spec, list):
- state_spec = [state_spec]
- self.input_spec = input_spec + state_spec
-
- # Compute the full inputs, including state
- inputs = [inputs] + list(initial_state)
-
- # Perform the call
- output = super(Recurrent, self).__call__(inputs, **kwargs)
-
- # Restore original input spec
- self.input_spec = input_spec
+ # Compute the full input spec, including state and constants
+ full_input = [inputs] + additional_inputs
+ full_input_spec = self.input_spec + additional_specs
+ # Perform the call with temporarily replaced input_spec
+ original_input_spec = self.input_spec
+ self.input_spec = full_input_spec
+ output = super(RNN, self).__call__(full_input, **kwargs)
+ self.input_spec = original_input_spec
return output
else:
- kwargs['initial_state'] = initial_state
- return super(Recurrent, self).__call__(inputs, **kwargs)
-
- def call(self, inputs, mask=None, training=None, initial_state=None):
+ return super(RNN, self).__call__(inputs, **kwargs)
+
+ def call(self,
+ inputs,
+ mask=None,
+ training=None,
+ initial_state=None,
+ constants=None):
# input shape: `(samples, time (padded with zeros), input_dim)`
# note that the .build() method of subclasses MUST define
# self.input_spec and self.state_spec with complete input shapes.
if isinstance(inputs, list):
- initial_state = inputs[1:]
inputs = inputs[0]
- elif initial_state is not None:
+ if initial_state is not None:
pass
elif self.stateful:
initial_state = self.states
@@ -343,13 +568,14 @@ class Recurrent(Layer):
mask = mask[0]
if len(initial_state) != len(self.states):
- raise ValueError('Layer has ' + str(len(self.states)) +
- ' states but was passed ' + str(len(initial_state)) +
- ' initial states.')
+ raise ValueError(
+ 'Layer has ' + str(len(self.states)) + ' states but was passed ' +
+ str(len(initial_state)) + ' initial states.')
input_shape = K.int_shape(inputs)
- if self.unroll and input_shape[1] is None:
+ timesteps = input_shape[1]
+ if self.unroll and timesteps in [None, 1]:
raise ValueError('Cannot unroll a RNN if the '
- 'time dimension is undefined. \n'
+ 'time dimension is undefined or equal to 1. \n'
'- If using a Sequential model, '
'specify the time dimension by passing '
'an `input_shape` or `batch_input_shape` '
@@ -359,15 +585,31 @@ class Recurrent(Layer):
'- If using the functional API, specify '
'the time dimension by passing a `shape` '
'or `batch_shape` argument to your Input layer.')
- constants = self.get_constants(inputs, training=None)
- preprocessed_input = self.preprocess_input(inputs, training=None)
+
+ kwargs = {}
+ if has_arg(self.cell.call, 'training'):
+ kwargs['training'] = training
+
+ if constants:
+ if not has_arg(self.cell.call, 'constants'):
+ raise ValueError('RNN cell does not support constants')
+
+ def step(inputs, states):
+ constants = states[-self._num_constants:] # pylint: disable=invalid-unary-operand-type
+ states = states[:-self._num_constants] # pylint: disable=invalid-unary-operand-type
+ return self.cell.call(inputs, states, constants=constants, **kwargs)
+ else:
+
+ def step(inputs, states):
+ return self.cell.call(inputs, states, **kwargs)
+
last_output, outputs, states = K.rnn(
- self.step,
- preprocessed_input,
+ step,
+ inputs,
initial_state,
+ constants=constants,
go_backwards=self.go_backwards,
mask=mask,
- constants=constants,
unroll=self.unroll)
if self.stateful:
updates = []
@@ -375,21 +617,63 @@ class Recurrent(Layer):
updates.append((self.states[i], states[i]))
self.add_update(updates, inputs)
- # Properly set learning phase
- if 0 < self.dropout + self.recurrent_dropout:
- last_output._uses_learning_phase = True
- outputs._uses_learning_phase = True
+ if self.return_sequences:
+ output = outputs
+ else:
+ output = last_output
- if not self.return_sequences:
- outputs = last_output
+ # Properly set learning phase
+ if getattr(last_output, '_uses_learning_phase', False):
+ output._uses_learning_phase = True
if self.return_state:
if not isinstance(states, (list, tuple)):
states = [states]
else:
states = list(states)
- return [outputs] + states
- return outputs
+ return [output] + states
+ else:
+ return output
+
+ def _standardize_args(self, inputs, initial_state, constants):
+ """Standardize `__call__` arguments to a single list of tensor inputs.
+
+ When running a model loaded from file, the input tensors
+ `initial_state` and `constants` can be passed to `RNN.__call__` as part
+ of `inputs` instead of by the dedicated keyword arguments. This method
+ makes sure the arguments are separated and that `initial_state` and
+ `constants` are lists of tensors (or None).
+
+ Arguments:
+ inputs: tensor or list/tuple of tensors
+ initial_state: tensor or list of tensors or None
+ constants: tensor or list of tensors or None
+
+ Returns:
+ inputs: tensor
+ initial_state: list of tensors or None
+ constants: list of tensors or None
+ """
+ if isinstance(inputs, list):
+ assert initial_state is None and constants is None
+ if self._num_constants is not None:
+ constants = inputs[-self._num_constants:] # pylint: disable=invalid-unary-operand-type
+ inputs = inputs[:-self._num_constants] # pylint: disable=invalid-unary-operand-type
+ if len(inputs) > 1:
+ initial_state = inputs[1:]
+ inputs = inputs[0]
+
+ def to_list_or_none(x):
+ if x is None or isinstance(x, list):
+ return x
+ if isinstance(x, tuple):
+ return list(x)
+ return [x]
+
+ initial_state = to_list_or_none(initial_state)
+ constants = to_list_or_none(constants)
+
+ return inputs, initial_state, constants
def reset_states(self, states=None):
if not self.stateful:
@@ -408,10 +692,19 @@ class Recurrent(Layer):
'`batch_shape` argument to your Input layer.')
# initialize state if None
if self.states[0] is None:
- self.states = [K.zeros((batch_size, self.units)) for _ in self.states]
+ if hasattr(self.cell.state_size, '__len__'):
+ self.states = [
+ K.zeros((batch_size, dim)) for dim in self.cell.state_size
+ ]
+ else:
+ self.states = [K.zeros((batch_size, self.cell.state_size))]
elif states is None:
- for state in self.states:
- K.set_value(state, np.zeros((batch_size, self.units)))
+ if hasattr(self.cell.state_size, '__len__'):
+ for state, dim in zip(self.states, self.cell.state_size):
+ K.set_value(state, np.zeros((batch_size, dim)))
+ else:
+ K.set_value(self.states[0], np.zeros((batch_size,
+ self.cell.state_size)))
else:
if not isinstance(states, (list, tuple)):
states = [states]
@@ -421,11 +714,16 @@ class Recurrent(Layer):
'but it received ' + str(len(states)) +
' state values. Input received: ' + str(states))
for index, (value, state) in enumerate(zip(states, self.states)):
- if value.shape != (batch_size, self.units):
- raise ValueError('State ' + str(index) +
- ' is incompatible with layer ' + self.name +
- ': expected shape=' + str((batch_size, self.units)) +
- ', found shape=' + str(value.shape))
+ if hasattr(self.cell.state_size, '__len__'):
+ dim = self.cell.state_size[index]
+ else:
+ dim = self.cell.state_size
+ if value.shape != (batch_size, dim):
+ raise ValueError(
+ 'State ' + str(index) + ' is incompatible with layer ' +
+ self.name + ': expected shape=' + str(
+ (batch_size, dim)) + ', found shape=' + str(value.shape))
+ # TODO(fchollet): consider batch calls to `set_value`.
K.set_value(state, value)
def get_config(self):
@@ -434,51 +732,94 @@ class Recurrent(Layer):
'return_state': self.return_state,
'go_backwards': self.go_backwards,
'stateful': self.stateful,
- 'unroll': self.unroll,
- 'implementation': self.implementation
+ 'unroll': self.unroll
}
- base_config = super(Recurrent, self).get_config()
+ if self._num_constants is not None:
+ config['num_constants'] = self._num_constants
+
+ cell_config = self.cell.get_config()
+ config['cell'] = {
+ 'class_name': self.cell.__class__.__name__,
+ 'config': cell_config
+ }
+ base_config = super(RNN, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ from tensorflow.python.keras._impl.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
+ cell = deserialize_layer(config.pop('cell'), custom_objects=custom_objects)
+ num_constants = config.pop('num_constants', None)
+ layer = cls(cell, **config)
+ layer._num_constants = num_constants
+ return layer
+
+ @property
+ def trainable_weights(self):
+ if isinstance(self.cell, Layer):
+ return self.cell.trainable_weights
+ return []
-class SimpleRNN(Recurrent):
- """Fully-connected RNN where the output is to be fed back to input.
+ @property
+ def non_trainable_weights(self):
+ if isinstance(self.cell, Layer):
+ return self.cell.non_trainable_weights
+ return []
+
+ @property
+ def losses(self):
+ if isinstance(self.cell, Layer):
+ return self.cell.losses
+ return []
+
+ def get_losses_for(self, inputs=None):
+ if isinstance(self.cell, Layer):
+ cell_losses = self.cell.get_losses_for(inputs)
+ return cell_losses + super(RNN, self).get_losses_for(inputs)
+ return super(RNN, self).get_losses_for(inputs)
+
+
+class SimpleRNNCell(Layer):
+ """Cell class for SimpleRNN.
Arguments:
units: Positive integer, dimensionality of the output space.
- activation: Activation function to use.
- If you don't specify anything, no activation is applied
+ activation: Activation function to use
+ (see [activations](../activations.md)).
If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
- used for the linear transformation of the inputs..
+ used for the linear transformation of the inputs.
+ (see [initializers](../initializers.md)).
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix,
- used for the linear transformation of the recurrent state..
- bias_initializer: Initializer for the bias vector.
+ used for the linear transformation of the recurrent state.
+ (see [initializers](../initializers.md)).
+ bias_initializer: Initializer for the bias vector
+ (see [initializers](../initializers.md)).
kernel_regularizer: Regularizer function applied to
- the `kernel` weights matrix.
+ the `kernel` weights matrix
+ (see [regularizer](../regularizers.md)).
recurrent_regularizer: Regularizer function applied to
- the `recurrent_kernel` weights matrix.
- bias_regularizer: Regularizer function applied to the bias vector.
- activity_regularizer: Regularizer function applied to
- the output of the layer (its "activation")..
+ the `recurrent_kernel` weights matrix
+ (see [regularizer](../regularizers.md)).
+ bias_regularizer: Regularizer function applied to the bias vector
+ (see [regularizer](../regularizers.md)).
kernel_constraint: Constraint function applied to
- the `kernel` weights matrix.
+ the `kernel` weights matrix
+ (see [constraints](../constraints.md)).
recurrent_constraint: Constraint function applied to
- the `recurrent_kernel` weights matrix.
- bias_constraint: Constraint function applied to the bias vector.
+ the `recurrent_kernel` weights matrix
+ (see [constraints](../constraints.md)).
+ bias_constraint: Constraint function applied to the bias vector
+ (see [constraints](../constraints.md)).
dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the inputs.
recurrent_dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the recurrent state.
-
- References:
- - [A Theoretically Grounded Application of Dropout in Recurrent Neural
- Networks](http://arxiv.org/abs/1512.05287)
"""
def __init__(self,
@@ -491,15 +832,13 @@ class SimpleRNN(Recurrent):
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
- activity_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
dropout=0.,
recurrent_dropout=0.,
**kwargs):
- super(SimpleRNN, self).__init__(
- activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
+ super(SimpleRNNCell, self).__init__(**kwargs)
self.units = units
self.activation = activations.get(activation)
self.use_bias = use_bias
@@ -518,23 +857,13 @@ class SimpleRNN(Recurrent):
self.dropout = min(1., max(0., dropout))
self.recurrent_dropout = min(1., max(0., recurrent_dropout))
- self.state_spec = InputSpec(shape=(None, self.units))
+ self.state_size = self.units
+ self._dropout_mask = None
+ self._recurrent_dropout_mask = None
def build(self, input_shape):
- if isinstance(input_shape, list):
- input_shape = input_shape[0]
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
-
- batch_size = input_shape[0] if self.stateful else None
- self.input_dim = input_shape[2]
- self.input_spec[0] = InputSpec(shape=(batch_size, None, self.input_dim))
-
- self.states = [None]
- if self.stateful:
- self.reset_states()
-
self.kernel = self.add_weight(
- shape=(self.input_dim, self.units),
+ shape=(input_shape[-1], self.units),
name='kernel',
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
@@ -556,146 +885,327 @@ class SimpleRNN(Recurrent):
self.bias = None
self.built = True
- def preprocess_input(self, inputs, training=None):
- if self.implementation > 0:
- return inputs
+ def _generate_dropout_mask(self, inputs, training=None):
+ if 0 < self.dropout < 1:
+ ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1))
+
+ def dropped_inputs():
+ return K.dropout(ones, self.dropout)
+
+ self._dropout_mask = K.in_train_phase(
+ dropped_inputs, ones, training=training)
else:
- input_shape = inputs.get_shape().as_list()
- input_dim = input_shape[2]
- timesteps = input_shape[1]
- return _time_distributed_dense(
- inputs,
- self.kernel,
- self.bias,
- self.dropout,
- input_dim,
- self.units,
- timesteps,
- training=training)
+ self._dropout_mask = None
- def step(self, inputs, states):
- if self.implementation == 0:
- h = inputs
+ def _generate_recurrent_dropout_mask(self, inputs, training=None):
+ if 0 < self.recurrent_dropout < 1:
+ ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
+ ones = K.tile(ones, (1, self.units))
+
+ def dropped_inputs():
+ return K.dropout(ones, self.dropout)
+
+ self._recurrent_dropout_mask = K.in_train_phase(
+ dropped_inputs, ones, training=training)
else:
- if 0 < self.dropout < 1:
- h = K.dot(inputs * states[1], self.kernel)
- else:
- h = K.dot(inputs, self.kernel)
- if self.bias is not None:
- h = K.bias_add(h, self.bias)
+ self._recurrent_dropout_mask = None
+ def call(self, inputs, states, training=None):
prev_output = states[0]
- if 0 < self.recurrent_dropout < 1:
- prev_output *= states[2]
+ dp_mask = self._dropout_mask
+ rec_dp_mask = self._recurrent_dropout_mask
+
+ if dp_mask is not None:
+ h = K.dot(inputs * dp_mask, self.kernel)
+ else:
+ h = K.dot(inputs, self.kernel)
+ if self.bias is not None:
+ h = K.bias_add(h, self.bias)
+
+ if rec_dp_mask is not None:
+ prev_output *= rec_dp_mask
output = h + K.dot(prev_output, self.recurrent_kernel)
if self.activation is not None:
output = self.activation(output)
# Properly set learning phase on output tensor.
if 0 < self.dropout + self.recurrent_dropout:
- output._uses_learning_phase = True
+ if training is None:
+ output._uses_learning_phase = True
return output, [output]
- def get_constants(self, inputs, training=None):
- constants = []
- if self.implementation != 0 and 0 < self.dropout < 1:
- input_shape = K.int_shape(inputs)
- input_dim = input_shape[-1]
- ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
- ones = K.tile(ones, (1, int(input_dim)))
- def dropped_inputs():
- return K.dropout(ones, self.dropout)
+class SimpleRNN(RNN):
+ """Fully-connected RNN where the output is to be fed back to input.
- dp_mask = K.in_train_phase(dropped_inputs, ones, training=training)
- constants.append(dp_mask)
- else:
- constants.append(K.cast_to_floatx(1.))
+ Arguments:
+ units: Positive integer, dimensionality of the output space.
+ activation: Activation function to use
+ (see [activations](../activations.md)).
+ If you pass None, no activation is applied
+ (ie. "linear" activation: `a(x) = x`).
+ use_bias: Boolean, whether the layer uses a bias vector.
+ kernel_initializer: Initializer for the `kernel` weights matrix,
+ used for the linear transformation of the inputs.
+ (see [initializers](../initializers.md)).
+ recurrent_initializer: Initializer for the `recurrent_kernel`
+ weights matrix,
+ used for the linear transformation of the recurrent state.
+ (see [initializers](../initializers.md)).
+ bias_initializer: Initializer for the bias vector
+ (see [initializers](../initializers.md)).
+ kernel_regularizer: Regularizer function applied to
+ the `kernel` weights matrix
+ (see [regularizer](../regularizers.md)).
+ recurrent_regularizer: Regularizer function applied to
+ the `recurrent_kernel` weights matrix
+ (see [regularizer](../regularizers.md)).
+ bias_regularizer: Regularizer function applied to the bias vector
+ (see [regularizer](../regularizers.md)).
+ activity_regularizer: Regularizer function applied to
+ the output of the layer (its "activation").
+ (see [regularizer](../regularizers.md)).
+ kernel_constraint: Constraint function applied to
+ the `kernel` weights matrix
+ (see [constraints](../constraints.md)).
+ recurrent_constraint: Constraint function applied to
+ the `recurrent_kernel` weights matrix
+ (see [constraints](../constraints.md)).
+ bias_constraint: Constraint function applied to the bias vector
+ (see [constraints](../constraints.md)).
+ dropout: Float between 0 and 1.
+ Fraction of the units to drop for
+ the linear transformation of the inputs.
+ recurrent_dropout: Float between 0 and 1.
+ Fraction of the units to drop for
+ the linear transformation of the recurrent state.
+ return_sequences: Boolean. Whether to return the last output.
+ in the output sequence, or the full sequence.
+ return_state: Boolean. Whether to return the last state
+ in addition to the output.
+ go_backwards: Boolean (default False).
+ If True, process the input sequence backwards and return the
+ reversed sequence.
+ stateful: Boolean (default False). If True, the last state
+ for each sample at index i in a batch will be used as initial
+ state for the sample of index i in the following batch.
+ unroll: Boolean (default False).
+ If True, the network will be unrolled,
+ else a symbolic loop will be used.
+ Unrolling can speed-up a RNN,
+ although it tends to be more memory-intensive.
+ Unrolling is only suitable for short sequences.
+ """
- if 0 < self.recurrent_dropout < 1:
- ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
- ones = K.tile(ones, (1, self.units))
+ def __init__(self,
+ units,
+ activation='tanh',
+ use_bias=True,
+ kernel_initializer='glorot_uniform',
+ recurrent_initializer='orthogonal',
+ bias_initializer='zeros',
+ kernel_regularizer=None,
+ recurrent_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ kernel_constraint=None,
+ recurrent_constraint=None,
+ bias_constraint=None,
+ dropout=0.,
+ recurrent_dropout=0.,
+ return_sequences=False,
+ return_state=False,
+ go_backwards=False,
+ stateful=False,
+ unroll=False,
+ **kwargs):
+ if 'implementation' in kwargs:
+ kwargs.pop('implementation')
+ logging.warning('The `implementation` argument '
+ 'in `SimpleRNN` has been deprecated. '
+ 'Please remove it from your layer call.')
+ cell = SimpleRNNCell(
+ units,
+ activation=activation,
+ use_bias=use_bias,
+ kernel_initializer=kernel_initializer,
+ recurrent_initializer=recurrent_initializer,
+ bias_initializer=bias_initializer,
+ kernel_regularizer=kernel_regularizer,
+ recurrent_regularizer=recurrent_regularizer,
+ bias_regularizer=bias_regularizer,
+ kernel_constraint=kernel_constraint,
+ recurrent_constraint=recurrent_constraint,
+ bias_constraint=bias_constraint,
+ dropout=dropout,
+ recurrent_dropout=recurrent_dropout)
+ super(SimpleRNN, self).__init__(
+ cell,
+ return_sequences=return_sequences,
+ return_state=return_state,
+ go_backwards=go_backwards,
+ stateful=stateful,
+ unroll=unroll,
+ activity_regularizer=regularizers.get(activity_regularizer),
+ **kwargs)
+ # self.activity_regularizer = regularizers.get(activity_regularizer)
- def dropped_inputs(): # pylint: disable=function-redefined
- return K.dropout(ones, self.recurrent_dropout)
+ def call(self, inputs, mask=None, training=None, initial_state=None):
+ self.cell._generate_dropout_mask(inputs, training=training)
+ self.cell._generate_recurrent_dropout_mask(inputs, training=training)
+ return super(SimpleRNN, self).call(
+ inputs, mask=mask, training=training, initial_state=initial_state)
- rec_dp_mask = K.in_train_phase(dropped_inputs, ones, training=training)
- constants.append(rec_dp_mask)
- else:
- constants.append(K.cast_to_floatx(1.))
- return constants
+ @property
+ def units(self):
+ return self.cell.units
+
+ @property
+ def activation(self):
+ return self.cell.activation
+
+ @property
+ def use_bias(self):
+ return self.cell.use_bias
+
+ @property
+ def kernel_initializer(self):
+ return self.cell.kernel_initializer
+
+ @property
+ def recurrent_initializer(self):
+ return self.cell.recurrent_initializer
+
+ @property
+ def bias_initializer(self):
+ return self.cell.bias_initializer
+
+ @property
+ def kernel_regularizer(self):
+ return self.cell.kernel_regularizer
+
+ @property
+ def recurrent_regularizer(self):
+ return self.cell.recurrent_regularizer
+
+ @property
+ def bias_regularizer(self):
+ return self.cell.bias_regularizer
+
+ @property
+ def kernel_constraint(self):
+ return self.cell.kernel_constraint
+
+ @property
+ def recurrent_constraint(self):
+ return self.cell.recurrent_constraint
+
+ @property
+ def bias_constraint(self):
+ return self.cell.bias_constraint
+
+ @property
+ def dropout(self):
+ return self.cell.dropout
+
+ @property
+ def recurrent_dropout(self):
+ return self.cell.recurrent_dropout
def get_config(self):
config = {
- 'units': self.units,
- 'activation': activations.serialize(self.activation),
- 'use_bias': self.use_bias,
- 'kernel_initializer': initializers.serialize(self.kernel_initializer),
+ 'units':
+ self.units,
+ 'activation':
+ activations.serialize(self.activation),
+ 'use_bias':
+ self.use_bias,
+ 'kernel_initializer':
+ initializers.serialize(self.kernel_initializer),
'recurrent_initializer':
initializers.serialize(self.recurrent_initializer),
- 'bias_initializer': initializers.serialize(self.bias_initializer),
- 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
+ 'bias_initializer':
+ initializers.serialize(self.bias_initializer),
+ 'kernel_regularizer':
+ regularizers.serialize(self.kernel_regularizer),
'recurrent_regularizer':
regularizers.serialize(self.recurrent_regularizer),
- 'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
- 'kernel_constraint': constraints.serialize(self.kernel_constraint),
+ 'kernel_constraint':
+ constraints.serialize(self.kernel_constraint),
'recurrent_constraint':
constraints.serialize(self.recurrent_constraint),
- 'bias_constraint': constraints.serialize(self.bias_constraint),
- 'dropout': self.dropout,
- 'recurrent_dropout': self.recurrent_dropout
+ 'bias_constraint':
+ constraints.serialize(self.bias_constraint),
+ 'dropout':
+ self.dropout,
+ 'recurrent_dropout':
+ self.recurrent_dropout
}
base_config = super(SimpleRNN, self).get_config()
+ del base_config['cell']
return dict(list(base_config.items()) + list(config.items()))
+ @classmethod
+ def from_config(cls, config):
+ if 'implementation' in config:
+ config.pop('implementation')
+ return cls(**config)
-class GRU(Recurrent):
- """Gated Recurrent Unit - Cho et al.
- 2014.
+class GRUCell(Layer):
+ """Cell class for the GRU layer.
Arguments:
units: Positive integer, dimensionality of the output space.
- activation: Activation function to use.
+ activation: Activation function to use
+ (see [activations](../activations.md)).
If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
recurrent_activation: Activation function to use
- for the recurrent step.
+ for the recurrent step
+ (see [activations](../activations.md)).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
- used for the linear transformation of the inputs..
+ used for the linear transformation of the inputs.
+ (see [initializers](../initializers.md)).
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix,
- used for the linear transformation of the recurrent state..
- bias_initializer: Initializer for the bias vector.
+ used for the linear transformation of the recurrent state.
+ (see [initializers](../initializers.md)).
+ bias_initializer: Initializer for the bias vector
+ (see [initializers](../initializers.md)).
kernel_regularizer: Regularizer function applied to
- the `kernel` weights matrix.
+ the `kernel` weights matrix
+ (see [regularizer](../regularizers.md)).
recurrent_regularizer: Regularizer function applied to
- the `recurrent_kernel` weights matrix.
- bias_regularizer: Regularizer function applied to the bias vector.
- activity_regularizer: Regularizer function applied to
- the output of the layer (its "activation")..
+ the `recurrent_kernel` weights matrix
+ (see [regularizer](../regularizers.md)).
+ bias_regularizer: Regularizer function applied to the bias vector
+ (see [regularizer](../regularizers.md)).
kernel_constraint: Constraint function applied to
- the `kernel` weights matrix.
+ the `kernel` weights matrix
+ (see [constraints](../constraints.md)).
recurrent_constraint: Constraint function applied to
- the `recurrent_kernel` weights matrix.
- bias_constraint: Constraint function applied to the bias vector.
+ the `recurrent_kernel` weights matrix
+ (see [constraints](../constraints.md)).
+ bias_constraint: Constraint function applied to the bias vector
+ (see [constraints](../constraints.md)).
dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the inputs.
recurrent_dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the recurrent state.
-
- References:
- - [On the Properties of Neural Machine Translation: Encoder-Decoder
- Approaches](https://arxiv.org/abs/1409.1259)
- - [Empirical Evaluation of Gated Recurrent Neural Networks on Sequence
- Modeling](http://arxiv.org/abs/1412.3555v1)
- - [A Theoretically Grounded Application of Dropout in Recurrent Neural
- Networks](http://arxiv.org/abs/1512.05287)
+ implementation: Implementation mode, either 1 or 2.
+ Mode 1 will structure its operations as a larger number of
+ smaller dot products and additions, whereas mode 2 will
+ batch them into fewer, larger operations. These modes will
+ have different performance profiles on different hardware and
+ for different applications.
"""
def __init__(self,
@@ -709,15 +1219,14 @@ class GRU(Recurrent):
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
- activity_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
dropout=0.,
recurrent_dropout=0.,
+ implementation=1,
**kwargs):
- super(GRU, self).__init__(
- activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
+ super(GRUCell, self).__init__(**kwargs)
self.units = units
self.activation = activations.get(activation)
self.recurrent_activation = activations.get(recurrent_activation)
@@ -737,22 +1246,15 @@ class GRU(Recurrent):
self.dropout = min(1., max(0., dropout))
self.recurrent_dropout = min(1., max(0., recurrent_dropout))
- self.state_spec = InputSpec(shape=(None, self.units))
+ self.implementation = implementation
+ self.state_size = self.units
+ self._dropout_mask = None
+ self._recurrent_dropout_mask = None
def build(self, input_shape):
- if isinstance(input_shape, list):
- input_shape = input_shape[0]
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
- batch_size = input_shape[0] if self.stateful else None
- self.input_dim = input_shape[2]
- self.input_spec[0] = InputSpec(shape=(batch_size, None, self.input_dim))
-
- self.states = [None]
- if self.stateful:
- self.reset_states()
-
+ input_dim = input_shape[-1]
self.kernel = self.add_weight(
- shape=(self.input_dim, self.units * 3),
+ shape=(input_dim, self.units * 3),
name='kernel',
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
@@ -792,89 +1294,83 @@ class GRU(Recurrent):
self.bias_h = None
self.built = True
- def preprocess_input(self, inputs, training=None):
- if self.implementation == 0:
- input_shape = inputs.get_shape().as_list()
- input_dim = input_shape[2]
- timesteps = input_shape[1]
-
- x_z = _time_distributed_dense(
- inputs,
- self.kernel_z,
- self.bias_z,
- self.dropout,
- input_dim,
- self.units,
- timesteps,
- training=training)
- x_r = _time_distributed_dense(
- inputs,
- self.kernel_r,
- self.bias_r,
- self.dropout,
- input_dim,
- self.units,
- timesteps,
- training=training)
- x_h = _time_distributed_dense(
- inputs,
- self.kernel_h,
- self.bias_h,
- self.dropout,
- input_dim,
- self.units,
- timesteps,
- training=training)
- return K.concatenate([x_z, x_r, x_h], axis=2)
- else:
- return inputs
-
- def get_constants(self, inputs, training=None):
- constants = []
- if self.implementation != 0 and 0 < self.dropout < 1:
- input_shape = K.int_shape(inputs)
- input_dim = input_shape[-1]
- ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
- ones = K.tile(ones, (1, int(input_dim)))
+ def _generate_dropout_mask(self, inputs, training=None):
+ if 0 < self.dropout < 1:
+ ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1))
def dropped_inputs():
return K.dropout(ones, self.dropout)
- dp_mask = [
+ self._dropout_mask = [
K.in_train_phase(dropped_inputs, ones, training=training)
for _ in range(3)
]
- constants.append(dp_mask)
else:
- constants.append([K.cast_to_floatx(1.) for _ in range(3)])
+ self._dropout_mask = None
+ def _generate_recurrent_dropout_mask(self, inputs, training=None):
if 0 < self.recurrent_dropout < 1:
ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
ones = K.tile(ones, (1, self.units))
- def dropped_inputs(): # pylint: disable=function-redefined
- return K.dropout(ones, self.recurrent_dropout)
+ def dropped_inputs():
+ return K.dropout(ones, self.dropout)
- rec_dp_mask = [
+ self._recurrent_dropout_mask = [
K.in_train_phase(dropped_inputs, ones, training=training)
for _ in range(3)
]
- constants.append(rec_dp_mask)
else:
- constants.append([K.cast_to_floatx(1.) for _ in range(3)])
- return constants
+ self._recurrent_dropout_mask = None
- def step(self, inputs, states):
+ def call(self, inputs, states, training=None):
h_tm1 = states[0] # previous memory
- dp_mask = states[1] # dropout matrices for recurrent units
- rec_dp_mask = states[2]
- if self.implementation == 2:
- matrix_x = K.dot(inputs * dp_mask[0], self.kernel)
+ # dropout matrices for input units
+ dp_mask = self._dropout_mask
+ # dropout matrices for recurrent units
+ rec_dp_mask = self._recurrent_dropout_mask
+
+ if self.implementation == 1:
+ if 0. < self.dropout < 1.:
+ inputs_z = inputs * dp_mask[0]
+ inputs_r = inputs * dp_mask[1]
+ inputs_h = inputs * dp_mask[2]
+ else:
+ inputs_z = inputs
+ inputs_r = inputs
+ inputs_h = inputs
+ x_z = K.dot(inputs_z, self.kernel_z)
+ x_r = K.dot(inputs_r, self.kernel_r)
+ x_h = K.dot(inputs_h, self.kernel_h)
+ if self.use_bias:
+ x_z = K.bias_add(x_z, self.bias_z)
+ x_r = K.bias_add(x_r, self.bias_r)
+ x_h = K.bias_add(x_h, self.bias_h)
+
+ if 0. < self.recurrent_dropout < 1.:
+ h_tm1_z = h_tm1 * rec_dp_mask[0]
+ h_tm1_r = h_tm1 * rec_dp_mask[1]
+ h_tm1_h = h_tm1 * rec_dp_mask[2]
+ else:
+ h_tm1_z = h_tm1
+ h_tm1_r = h_tm1
+ h_tm1_h = h_tm1
+ z = self.recurrent_activation(
+ x_z + K.dot(h_tm1_z, self.recurrent_kernel_z))
+ r = self.recurrent_activation(
+ x_r + K.dot(h_tm1_r, self.recurrent_kernel_r))
+
+ hh = self.activation(x_h + K.dot(r * h_tm1_h, self.recurrent_kernel_h))
+ else:
+ if 0. < self.dropout < 1.:
+ inputs *= dp_mask[0]
+ matrix_x = K.dot(inputs, self.kernel)
if self.use_bias:
matrix_x = K.bias_add(matrix_x, self.bias)
- matrix_inner = K.dot(h_tm1 * rec_dp_mask[0],
- self.recurrent_kernel[:, :2 * self.units])
+ if 0. < self.recurrent_dropout < 1.:
+ h_tm1 *= rec_dp_mask[0]
+ matrix_inner = K.dot(h_tm1, self.recurrent_kernel[:, :2 * self.units])
x_z = matrix_x[:, :self.units]
x_r = matrix_x[:, self.units:2 * self.units]
@@ -885,116 +1381,323 @@ class GRU(Recurrent):
r = self.recurrent_activation(x_r + recurrent_r)
x_h = matrix_x[:, 2 * self.units:]
- recurrent_h = K.dot(r * h_tm1 * rec_dp_mask[0],
- self.recurrent_kernel[:, 2 * self.units:])
+ recurrent_h = K.dot(r * h_tm1, self.recurrent_kernel[:, 2 * self.units:])
hh = self.activation(x_h + recurrent_h)
- else:
- if self.implementation == 0:
- x_z = inputs[:, :self.units]
- x_r = inputs[:, self.units:2 * self.units]
- x_h = inputs[:, 2 * self.units:]
- elif self.implementation == 1:
- x_z = K.dot(inputs * dp_mask[0], self.kernel_z)
- x_r = K.dot(inputs * dp_mask[1], self.kernel_r)
- x_h = K.dot(inputs * dp_mask[2], self.kernel_h)
- if self.use_bias:
- x_z = K.bias_add(x_z, self.bias_z)
- x_r = K.bias_add(x_r, self.bias_r)
- x_h = K.bias_add(x_h, self.bias_h)
- else:
- raise ValueError('Unknown `implementation` mode.')
- z = self.recurrent_activation(x_z + K.dot(h_tm1 * rec_dp_mask[0],
- self.recurrent_kernel_z))
- r = self.recurrent_activation(x_r + K.dot(h_tm1 * rec_dp_mask[1],
- self.recurrent_kernel_r))
-
- hh = self.activation(x_h + K.dot(r * h_tm1 * rec_dp_mask[2],
- self.recurrent_kernel_h))
h = z * h_tm1 + (1 - z) * hh
if 0 < self.dropout + self.recurrent_dropout:
- h._uses_learning_phase = True
+ if training is None:
+ h._uses_learning_phase = True
return h, [h]
+
+class GRU(RNN):
+ # pylint: disable=line-too-long
+ """Gated Recurrent Unit - Cho et al.
+
+ 2014.
+
+ Arguments:
+ units: Positive integer, dimensionality of the output space.
+ activation: Activation function to use
+ (see [activations](../activations.md)).
+ If you pass None, no activation is applied
+ (ie. "linear" activation: `a(x) = x`).
+ recurrent_activation: Activation function to use
+ for the recurrent step
+ (see [activations](../activations.md)).
+ use_bias: Boolean, whether the layer uses a bias vector.
+ kernel_initializer: Initializer for the `kernel` weights matrix,
+ used for the linear transformation of the inputs.
+ (see [initializers](../initializers.md)).
+ recurrent_initializer: Initializer for the `recurrent_kernel`
+ weights matrix,
+ used for the linear transformation of the recurrent state.
+ (see [initializers](../initializers.md)).
+ bias_initializer: Initializer for the bias vector
+ (see [initializers](../initializers.md)).
+ kernel_regularizer: Regularizer function applied to
+ the `kernel` weights matrix
+ (see [regularizer](../regularizers.md)).
+ recurrent_regularizer: Regularizer function applied to
+ the `recurrent_kernel` weights matrix
+ (see [regularizer](../regularizers.md)).
+ bias_regularizer: Regularizer function applied to the bias vector
+ (see [regularizer](../regularizers.md)).
+ activity_regularizer: Regularizer function applied to
+ the output of the layer (its "activation").
+ (see [regularizer](../regularizers.md)).
+ kernel_constraint: Constraint function applied to
+ the `kernel` weights matrix
+ (see [constraints](../constraints.md)).
+ recurrent_constraint: Constraint function applied to
+ the `recurrent_kernel` weights matrix
+ (see [constraints](../constraints.md)).
+ bias_constraint: Constraint function applied to the bias vector
+ (see [constraints](../constraints.md)).
+ dropout: Float between 0 and 1.
+ Fraction of the units to drop for
+ the linear transformation of the inputs.
+ recurrent_dropout: Float between 0 and 1.
+ Fraction of the units to drop for
+ the linear transformation of the recurrent state.
+ implementation: Implementation mode, either 1 or 2.
+ Mode 1 will structure its operations as a larger number of
+ smaller dot products and additions, whereas mode 2 will
+ batch them into fewer, larger operations. These modes will
+ have different performance profiles on different hardware and
+ for different applications.
+ return_sequences: Boolean. Whether to return the last output.
+ in the output sequence, or the full sequence.
+ return_state: Boolean. Whether to return the last state
+ in addition to the output.
+ go_backwards: Boolean (default False).
+ If True, process the input sequence backwards and return the
+ reversed sequence.
+ stateful: Boolean (default False). If True, the last state
+ for each sample at index i in a batch will be used as initial
+ state for the sample of index i in the following batch.
+ unroll: Boolean (default False).
+ If True, the network will be unrolled,
+ else a symbolic loop will be used.
+ Unrolling can speed-up a RNN,
+ although it tends to be more memory-intensive.
+ Unrolling is only suitable for short sequences.
+
+ References:
+ - [On the Properties of Neural Machine Translation: Encoder-Decoder Approaches](https://arxiv.org/abs/1409.1259)
+ - [Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling](http://arxiv.org/abs/1412.3555v1)
+ - [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](http://arxiv.org/abs/1512.05287)
+ """
+ # pylint: enable=line-too-long
+
+ def __init__(self,
+ units,
+ activation='tanh',
+ recurrent_activation='hard_sigmoid',
+ use_bias=True,
+ kernel_initializer='glorot_uniform',
+ recurrent_initializer='orthogonal',
+ bias_initializer='zeros',
+ kernel_regularizer=None,
+ recurrent_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ kernel_constraint=None,
+ recurrent_constraint=None,
+ bias_constraint=None,
+ dropout=0.,
+ recurrent_dropout=0.,
+ implementation=1,
+ return_sequences=False,
+ return_state=False,
+ go_backwards=False,
+ stateful=False,
+ unroll=False,
+ **kwargs):
+ if implementation == 0:
+ logging.warning('`implementation=0` has been deprecated, '
+ 'and now defaults to `implementation=1`.'
+ 'Please update your layer call.')
+ cell = GRUCell(
+ units,
+ activation=activation,
+ recurrent_activation=recurrent_activation,
+ use_bias=use_bias,
+ kernel_initializer=kernel_initializer,
+ recurrent_initializer=recurrent_initializer,
+ bias_initializer=bias_initializer,
+ kernel_regularizer=kernel_regularizer,
+ recurrent_regularizer=recurrent_regularizer,
+ bias_regularizer=bias_regularizer,
+ kernel_constraint=kernel_constraint,
+ recurrent_constraint=recurrent_constraint,
+ bias_constraint=bias_constraint,
+ dropout=dropout,
+ recurrent_dropout=recurrent_dropout,
+ implementation=implementation)
+ super(GRU, self).__init__(
+ cell,
+ return_sequences=return_sequences,
+ return_state=return_state,
+ go_backwards=go_backwards,
+ stateful=stateful,
+ unroll=unroll,
+ **kwargs)
+ self.activity_regularizer = regularizers.get(activity_regularizer)
+
+ def call(self, inputs, mask=None, training=None, initial_state=None):
+ self.cell._generate_dropout_mask(inputs, training=training)
+ self.cell._generate_recurrent_dropout_mask(inputs, training=training)
+ return super(GRU, self).call(
+ inputs, mask=mask, training=training, initial_state=initial_state)
+
+ @property
+ def units(self):
+ return self.cell.units
+
+ @property
+ def activation(self):
+ return self.cell.activation
+
+ @property
+ def recurrent_activation(self):
+ return self.cell.recurrent_activation
+
+ @property
+ def use_bias(self):
+ return self.cell.use_bias
+
+ @property
+ def kernel_initializer(self):
+ return self.cell.kernel_initializer
+
+ @property
+ def recurrent_initializer(self):
+ return self.cell.recurrent_initializer
+
+ @property
+ def bias_initializer(self):
+ return self.cell.bias_initializer
+
+ @property
+ def kernel_regularizer(self):
+ return self.cell.kernel_regularizer
+
+ @property
+ def recurrent_regularizer(self):
+ return self.cell.recurrent_regularizer
+
+ @property
+ def bias_regularizer(self):
+ return self.cell.bias_regularizer
+
+ @property
+ def kernel_constraint(self):
+ return self.cell.kernel_constraint
+
+ @property
+ def recurrent_constraint(self):
+ return self.cell.recurrent_constraint
+
+ @property
+ def bias_constraint(self):
+ return self.cell.bias_constraint
+
+ @property
+ def dropout(self):
+ return self.cell.dropout
+
+ @property
+ def recurrent_dropout(self):
+ return self.cell.recurrent_dropout
+
+ @property
+ def implementation(self):
+ return self.cell.implementation
+
def get_config(self):
config = {
- 'units': self.units,
- 'activation': activations.serialize(self.activation),
+ 'units':
+ self.units,
+ 'activation':
+ activations.serialize(self.activation),
'recurrent_activation':
activations.serialize(self.recurrent_activation),
- 'use_bias': self.use_bias,
- 'kernel_initializer': initializers.serialize(self.kernel_initializer),
+ 'use_bias':
+ self.use_bias,
+ 'kernel_initializer':
+ initializers.serialize(self.kernel_initializer),
'recurrent_initializer':
initializers.serialize(self.recurrent_initializer),
- 'bias_initializer': initializers.serialize(self.bias_initializer),
- 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
+ 'bias_initializer':
+ initializers.serialize(self.bias_initializer),
+ 'kernel_regularizer':
+ regularizers.serialize(self.kernel_regularizer),
'recurrent_regularizer':
regularizers.serialize(self.recurrent_regularizer),
- 'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
- 'kernel_constraint': constraints.serialize(self.kernel_constraint),
+ 'kernel_constraint':
+ constraints.serialize(self.kernel_constraint),
'recurrent_constraint':
constraints.serialize(self.recurrent_constraint),
- 'bias_constraint': constraints.serialize(self.bias_constraint),
- 'dropout': self.dropout,
- 'recurrent_dropout': self.recurrent_dropout
+ 'bias_constraint':
+ constraints.serialize(self.bias_constraint),
+ 'dropout':
+ self.dropout,
+ 'recurrent_dropout':
+ self.recurrent_dropout,
+ 'implementation':
+ self.implementation
}
base_config = super(GRU, self).get_config()
+ del base_config['cell']
return dict(list(base_config.items()) + list(config.items()))
+ @classmethod
+ def from_config(cls, config):
+ if 'implementation' in config and config['implementation'] == 0:
+ config['implementation'] = 1
+ return cls(**config)
-class LSTM(Recurrent):
- """Long-Short Term Memory unit - Hochreiter 1997.
- For a step-by-step description of the algorithm, see
- [this tutorial](http://deeplearning.net/tutorial/lstm.html).
+class LSTMCell(Layer):
+ """Cell class for the LSTM layer.
Arguments:
units: Positive integer, dimensionality of the output space.
- activation: Activation function to use.
+ activation: Activation function to use
+ (see [activations](../activations.md)).
If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
recurrent_activation: Activation function to use
- for the recurrent step.
+ for the recurrent step
+ (see [activations](../activations.md)).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
- used for the linear transformation of the inputs..
+ used for the linear transformation of the inputs.
+ (see [initializers](../initializers.md)).
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix,
- used for the linear transformation of the recurrent state..
- bias_initializer: Initializer for the bias vector.
+ used for the linear transformation of the recurrent state.
+ (see [initializers](../initializers.md)).
+ bias_initializer: Initializer for the bias vector
+ (see [initializers](../initializers.md)).
unit_forget_bias: Boolean.
If True, add 1 to the bias of the forget gate at initialization.
Setting it to true will also force `bias_initializer="zeros"`.
This is recommended in [Jozefowicz et
al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
kernel_regularizer: Regularizer function applied to
- the `kernel` weights matrix.
+ the `kernel` weights matrix
+ (see [regularizer](../regularizers.md)).
recurrent_regularizer: Regularizer function applied to
- the `recurrent_kernel` weights matrix.
- bias_regularizer: Regularizer function applied to the bias vector.
- activity_regularizer: Regularizer function applied to
- the output of the layer (its "activation")..
+ the `recurrent_kernel` weights matrix
+ (see [regularizer](../regularizers.md)).
+ bias_regularizer: Regularizer function applied to the bias vector
+ (see [regularizer](../regularizers.md)).
kernel_constraint: Constraint function applied to
- the `kernel` weights matrix.
+ the `kernel` weights matrix
+ (see [constraints](../constraints.md)).
recurrent_constraint: Constraint function applied to
- the `recurrent_kernel` weights matrix.
- bias_constraint: Constraint function applied to the bias vector.
+ the `recurrent_kernel` weights matrix
+ (see [constraints](../constraints.md)).
+ bias_constraint: Constraint function applied to the bias vector
+ (see [constraints](../constraints.md)).
dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the inputs.
recurrent_dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the recurrent state.
-
- References:
- - [Long short-term
- memory]((http://www.bioinf.jku.at/publications/older/2604.pdf)
- (original 1997 paper)
- - [Supervised sequence labeling with recurrent neural
- networks](http://www.cs.toronto.edu/~graves/preprint.pdf)
- - [A Theoretically Grounded Application of Dropout in Recurrent Neural
- Networks](http://arxiv.org/abs/1512.05287)
+ implementation: Implementation mode, either 1 or 2.
+ Mode 1 will structure its operations as a larger number of
+ smaller dot products and additions, whereas mode 2 will
+ batch them into fewer, larger operations. These modes will
+ have different performance profiles on different hardware and
+ for different applications.
"""
def __init__(self,
@@ -1009,15 +1712,14 @@ class LSTM(Recurrent):
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
- activity_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
dropout=0.,
recurrent_dropout=0.,
+ implementation=1,
**kwargs):
- super(LSTM, self).__init__(
- activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
+ super(LSTMCell, self).__init__(**kwargs)
self.units = units
self.activation = activations.get(activation)
self.recurrent_activation = activations.get(recurrent_activation)
@@ -1038,25 +1740,15 @@ class LSTM(Recurrent):
self.dropout = min(1., max(0., dropout))
self.recurrent_dropout = min(1., max(0., recurrent_dropout))
- self.state_spec = [
- InputSpec(shape=(None, self.units)),
- InputSpec(shape=(None, self.units))
- ]
+ self.implementation = implementation
+ self.state_size = (self.units, self.units)
+ self._dropout_mask = None
+ self._recurrent_dropout_mask = None
def build(self, input_shape):
- if isinstance(input_shape, list):
- input_shape = input_shape[0]
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
- batch_size = input_shape[0] if self.stateful else None
- self.input_dim = input_shape[2]
- self.input_spec[0] = InputSpec(shape=(batch_size, None, self.input_dim))
-
- self.states = [None, None]
- if self.stateful:
- self.reset_states()
-
+ input_dim = input_shape[-1]
self.kernel = self.add_weight(
- shape=(self.input_dim, self.units * 4),
+ shape=(input_dim, self.units * 4),
name='kernel',
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
@@ -1112,96 +1804,90 @@ class LSTM(Recurrent):
self.bias_o = None
self.built = True
- def preprocess_input(self, inputs, training=None):
- if self.implementation == 0:
- input_shape = inputs.get_shape().as_list()
- input_dim = input_shape[2]
- timesteps = input_shape[1]
-
- x_i = _time_distributed_dense(
- inputs,
- self.kernel_i,
- self.bias_i,
- self.dropout,
- input_dim,
- self.units,
- timesteps,
- training=training)
- x_f = _time_distributed_dense(
- inputs,
- self.kernel_f,
- self.bias_f,
- self.dropout,
- input_dim,
- self.units,
- timesteps,
- training=training)
- x_c = _time_distributed_dense(
- inputs,
- self.kernel_c,
- self.bias_c,
- self.dropout,
- input_dim,
- self.units,
- timesteps,
- training=training)
- x_o = _time_distributed_dense(
- inputs,
- self.kernel_o,
- self.bias_o,
- self.dropout,
- input_dim,
- self.units,
- timesteps,
- training=training)
- return K.concatenate([x_i, x_f, x_c, x_o], axis=2)
- else:
- return inputs
-
- def get_constants(self, inputs, training=None):
- constants = []
- if self.implementation != 0 and 0 < self.dropout < 1:
- input_shape = K.int_shape(inputs)
- input_dim = input_shape[-1]
- ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
- ones = K.tile(ones, (1, int(input_dim)))
+ def _generate_dropout_mask(self, inputs, training=None):
+ if 0 < self.dropout < 1:
+ ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1))
def dropped_inputs():
return K.dropout(ones, self.dropout)
- dp_mask = [
+ self._dropout_mask = [
K.in_train_phase(dropped_inputs, ones, training=training)
for _ in range(4)
]
- constants.append(dp_mask)
else:
- constants.append([K.cast_to_floatx(1.) for _ in range(4)])
+ self._dropout_mask = None
+ def _generate_recurrent_dropout_mask(self, inputs, training=None):
if 0 < self.recurrent_dropout < 1:
ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
ones = K.tile(ones, (1, self.units))
- def dropped_inputs(): # pylint: disable=function-redefined
- return K.dropout(ones, self.recurrent_dropout)
+ def dropped_inputs():
+ return K.dropout(ones, self.dropout)
- rec_dp_mask = [
+ self._recurrent_dropout_mask = [
K.in_train_phase(dropped_inputs, ones, training=training)
for _ in range(4)
]
- constants.append(rec_dp_mask)
else:
- constants.append([K.cast_to_floatx(1.) for _ in range(4)])
- return constants
-
- def step(self, inputs, states):
- h_tm1 = states[0]
- c_tm1 = states[1]
- dp_mask = states[2]
- rec_dp_mask = states[3]
-
- if self.implementation == 2:
- z = K.dot(inputs * dp_mask[0], self.kernel)
- z += K.dot(h_tm1 * rec_dp_mask[0], self.recurrent_kernel)
+ self._recurrent_dropout_mask = None
+
+ def call(self, inputs, states, training=None):
+ # dropout matrices for input units
+ dp_mask = self._dropout_mask
+ # dropout matrices for recurrent units
+ rec_dp_mask = self._recurrent_dropout_mask
+
+ h_tm1 = states[0] # previous memory state
+ c_tm1 = states[1] # previous carry state
+
+ if self.implementation == 1:
+ if 0 < self.dropout < 1.:
+ inputs_i = inputs * dp_mask[0]
+ inputs_f = inputs * dp_mask[1]
+ inputs_c = inputs * dp_mask[2]
+ inputs_o = inputs * dp_mask[3]
+ else:
+ inputs_i = inputs
+ inputs_f = inputs
+ inputs_c = inputs
+ inputs_o = inputs
+ x_i = K.dot(inputs_i, self.kernel_i)
+ x_f = K.dot(inputs_f, self.kernel_f)
+ x_c = K.dot(inputs_c, self.kernel_c)
+ x_o = K.dot(inputs_o, self.kernel_o)
+ if self.use_bias:
+ x_i = K.bias_add(x_i, self.bias_i)
+ x_f = K.bias_add(x_f, self.bias_f)
+ x_c = K.bias_add(x_c, self.bias_c)
+ x_o = K.bias_add(x_o, self.bias_o)
+
+ if 0 < self.recurrent_dropout < 1.:
+ h_tm1_i = h_tm1 * rec_dp_mask[0]
+ h_tm1_f = h_tm1 * rec_dp_mask[1]
+ h_tm1_c = h_tm1 * rec_dp_mask[2]
+ h_tm1_o = h_tm1 * rec_dp_mask[3]
+ else:
+ h_tm1_i = h_tm1
+ h_tm1_f = h_tm1
+ h_tm1_c = h_tm1
+ h_tm1_o = h_tm1
+ i = self.recurrent_activation(
+ x_i + K.dot(h_tm1_i, self.recurrent_kernel_i))
+ f = self.recurrent_activation(
+ x_f + K.dot(h_tm1_f, self.recurrent_kernel_f))
+ c = f * c_tm1 + i * self.activation(
+ x_c + K.dot(h_tm1_c, self.recurrent_kernel_c))
+ o = self.recurrent_activation(
+ x_o + K.dot(h_tm1_o, self.recurrent_kernel_o))
+ else:
+ if 0. < self.dropout < 1.:
+ inputs *= dp_mask[0]
+ z = K.dot(inputs, self.kernel)
+ if 0. < self.recurrent_dropout < 1.:
+ h_tm1 *= rec_dp_mask[0]
+ z += K.dot(h_tm1, self.recurrent_kernel)
if self.use_bias:
z = K.bias_add(z, self.bias)
@@ -1214,57 +1900,606 @@ class LSTM(Recurrent):
f = self.recurrent_activation(z1)
c = f * c_tm1 + i * self.activation(z2)
o = self.recurrent_activation(z3)
- else:
- if self.implementation == 0:
- x_i = inputs[:, :self.units]
- x_f = inputs[:, self.units:2 * self.units]
- x_c = inputs[:, 2 * self.units:3 * self.units]
- x_o = inputs[:, 3 * self.units:]
- elif self.implementation == 1:
- x_i = K.dot(inputs * dp_mask[0], self.kernel_i) + self.bias_i
- x_f = K.dot(inputs * dp_mask[1], self.kernel_f) + self.bias_f
- x_c = K.dot(inputs * dp_mask[2], self.kernel_c) + self.bias_c
- x_o = K.dot(inputs * dp_mask[3], self.kernel_o) + self.bias_o
- else:
- raise ValueError('Unknown `implementation` mode.')
- i = self.recurrent_activation(x_i + K.dot(h_tm1 * rec_dp_mask[0],
- self.recurrent_kernel_i))
- f = self.recurrent_activation(x_f + K.dot(h_tm1 * rec_dp_mask[1],
- self.recurrent_kernel_f))
- c = f * c_tm1 + i * self.activation(
- x_c + K.dot(h_tm1 * rec_dp_mask[2], self.recurrent_kernel_c))
- o = self.recurrent_activation(x_o + K.dot(h_tm1 * rec_dp_mask[3],
- self.recurrent_kernel_o))
h = o * self.activation(c)
if 0 < self.dropout + self.recurrent_dropout:
- h._uses_learning_phase = True
+ if training is None:
+ h._uses_learning_phase = True
return h, [h, c]
+
+class LSTM(RNN):
+ # pylint: disable=line-too-long
+ """Long-Short Term Memory layer - Hochreiter 1997.
+
+ Arguments:
+ units: Positive integer, dimensionality of the output space.
+ activation: Activation function to use
+ (see [activations](../activations.md)).
+ If you pass None, no activation is applied
+ (ie. "linear" activation: `a(x) = x`).
+ recurrent_activation: Activation function to use
+ for the recurrent step
+ (see [activations](../activations.md)).
+ use_bias: Boolean, whether the layer uses a bias vector.
+ kernel_initializer: Initializer for the `kernel` weights matrix,
+ used for the linear transformation of the inputs.
+ (see [initializers](../initializers.md)).
+ recurrent_initializer: Initializer for the `recurrent_kernel`
+ weights matrix,
+ used for the linear transformation of the recurrent state.
+ (see [initializers](../initializers.md)).
+ bias_initializer: Initializer for the bias vector
+ (see [initializers](../initializers.md)).
+ unit_forget_bias: Boolean.
+ If True, add 1 to the bias of the forget gate at initialization.
+ Setting it to true will also force `bias_initializer="zeros"`.
+ This is recommended in [Jozefowicz et
+ al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
+ kernel_regularizer: Regularizer function applied to
+ the `kernel` weights matrix
+ (see [regularizer](../regularizers.md)).
+ recurrent_regularizer: Regularizer function applied to
+ the `recurrent_kernel` weights matrix
+ (see [regularizer](../regularizers.md)).
+ bias_regularizer: Regularizer function applied to the bias vector
+ (see [regularizer](../regularizers.md)).
+ activity_regularizer: Regularizer function applied to
+ the output of the layer (its "activation").
+ (see [regularizer](../regularizers.md)).
+ kernel_constraint: Constraint function applied to
+ the `kernel` weights matrix
+ (see [constraints](../constraints.md)).
+ recurrent_constraint: Constraint function applied to
+ the `recurrent_kernel` weights matrix
+ (see [constraints](../constraints.md)).
+ bias_constraint: Constraint function applied to the bias vector
+ (see [constraints](../constraints.md)).
+ dropout: Float between 0 and 1.
+ Fraction of the units to drop for
+ the linear transformation of the inputs.
+ recurrent_dropout: Float between 0 and 1.
+ Fraction of the units to drop for
+ the linear transformation of the recurrent state.
+ implementation: Implementation mode, either 1 or 2.
+ Mode 1 will structure its operations as a larger number of
+ smaller dot products and additions, whereas mode 2 will
+ batch them into fewer, larger operations. These modes will
+ have different performance profiles on different hardware and
+ for different applications.
+ return_sequences: Boolean. Whether to return the last output.
+ in the output sequence, or the full sequence.
+ return_state: Boolean. Whether to return the last state
+ in addition to the output.
+ go_backwards: Boolean (default False).
+ If True, process the input sequence backwards and return the
+ reversed sequence.
+ stateful: Boolean (default False). If True, the last state
+ for each sample at index i in a batch will be used as initial
+ state for the sample of index i in the following batch.
+ unroll: Boolean (default False).
+ If True, the network will be unrolled,
+ else a symbolic loop will be used.
+ Unrolling can speed-up a RNN,
+ although it tends to be more memory-intensive.
+ Unrolling is only suitable for short sequences.
+
+ References:
+ - [Long short-term memory](http://www.bioinf.jku.at/publications/older/2604.pdf)
+ - [Learning to forget: Continual prediction with LSTM](http://www.mitpressjournals.org/doi/pdf/10.1162/089976600300015015)
+ - [Supervised sequence labeling with recurrent neural networks](http://www.cs.toronto.edu/~graves/preprint.pdf)
+ - [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](http://arxiv.org/abs/1512.05287)
+ """
+ # pylint: enable=line-too-long
+
+ def __init__(self,
+ units,
+ activation='tanh',
+ recurrent_activation='hard_sigmoid',
+ use_bias=True,
+ kernel_initializer='glorot_uniform',
+ recurrent_initializer='orthogonal',
+ bias_initializer='zeros',
+ unit_forget_bias=True,
+ kernel_regularizer=None,
+ recurrent_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ kernel_constraint=None,
+ recurrent_constraint=None,
+ bias_constraint=None,
+ dropout=0.,
+ recurrent_dropout=0.,
+ implementation=1,
+ return_sequences=False,
+ return_state=False,
+ go_backwards=False,
+ stateful=False,
+ unroll=False,
+ **kwargs):
+ if implementation == 0:
+ logging.warning('`implementation=0` has been deprecated, '
+ 'and now defaults to `implementation=1`.'
+ 'Please update your layer call.')
+ cell = LSTMCell(
+ units,
+ activation=activation,
+ recurrent_activation=recurrent_activation,
+ use_bias=use_bias,
+ kernel_initializer=kernel_initializer,
+ recurrent_initializer=recurrent_initializer,
+ unit_forget_bias=unit_forget_bias,
+ bias_initializer=bias_initializer,
+ kernel_regularizer=kernel_regularizer,
+ recurrent_regularizer=recurrent_regularizer,
+ bias_regularizer=bias_regularizer,
+ kernel_constraint=kernel_constraint,
+ recurrent_constraint=recurrent_constraint,
+ bias_constraint=bias_constraint,
+ dropout=dropout,
+ recurrent_dropout=recurrent_dropout,
+ implementation=implementation)
+ super(LSTM, self).__init__(
+ cell,
+ return_sequences=return_sequences,
+ return_state=return_state,
+ go_backwards=go_backwards,
+ stateful=stateful,
+ unroll=unroll,
+ **kwargs)
+ self.activity_regularizer = regularizers.get(activity_regularizer)
+
+ def call(self, inputs, mask=None, training=None, initial_state=None):
+ self.cell._generate_dropout_mask(inputs, training=training)
+ self.cell._generate_recurrent_dropout_mask(inputs, training=training)
+ return super(LSTM, self).call(
+ inputs, mask=mask, training=training, initial_state=initial_state)
+
+ @property
+ def units(self):
+ return self.cell.units
+
+ @property
+ def activation(self):
+ return self.cell.activation
+
+ @property
+ def recurrent_activation(self):
+ return self.cell.recurrent_activation
+
+ @property
+ def use_bias(self):
+ return self.cell.use_bias
+
+ @property
+ def kernel_initializer(self):
+ return self.cell.kernel_initializer
+
+ @property
+ def recurrent_initializer(self):
+ return self.cell.recurrent_initializer
+
+ @property
+ def bias_initializer(self):
+ return self.cell.bias_initializer
+
+ @property
+ def unit_forget_bias(self):
+ return self.cell.unit_forget_bias
+
+ @property
+ def kernel_regularizer(self):
+ return self.cell.kernel_regularizer
+
+ @property
+ def recurrent_regularizer(self):
+ return self.cell.recurrent_regularizer
+
+ @property
+ def bias_regularizer(self):
+ return self.cell.bias_regularizer
+
+ @property
+ def kernel_constraint(self):
+ return self.cell.kernel_constraint
+
+ @property
+ def recurrent_constraint(self):
+ return self.cell.recurrent_constraint
+
+ @property
+ def bias_constraint(self):
+ return self.cell.bias_constraint
+
+ @property
+ def dropout(self):
+ return self.cell.dropout
+
+ @property
+ def recurrent_dropout(self):
+ return self.cell.recurrent_dropout
+
+ @property
+ def implementation(self):
+ return self.cell.implementation
+
def get_config(self):
config = {
- 'units': self.units,
- 'activation': activations.serialize(self.activation),
+ 'units':
+ self.units,
+ 'activation':
+ activations.serialize(self.activation),
'recurrent_activation':
activations.serialize(self.recurrent_activation),
- 'use_bias': self.use_bias,
- 'kernel_initializer': initializers.serialize(self.kernel_initializer),
+ 'use_bias':
+ self.use_bias,
+ 'kernel_initializer':
+ initializers.serialize(self.kernel_initializer),
'recurrent_initializer':
initializers.serialize(self.recurrent_initializer),
- 'bias_initializer': initializers.serialize(self.bias_initializer),
- 'unit_forget_bias': self.unit_forget_bias,
- 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
+ 'bias_initializer':
+ initializers.serialize(self.bias_initializer),
+ 'unit_forget_bias':
+ self.unit_forget_bias,
+ 'kernel_regularizer':
+ regularizers.serialize(self.kernel_regularizer),
'recurrent_regularizer':
regularizers.serialize(self.recurrent_regularizer),
- 'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
- 'kernel_constraint': constraints.serialize(self.kernel_constraint),
+ 'kernel_constraint':
+ constraints.serialize(self.kernel_constraint),
'recurrent_constraint':
constraints.serialize(self.recurrent_constraint),
- 'bias_constraint': constraints.serialize(self.bias_constraint),
- 'dropout': self.dropout,
- 'recurrent_dropout': self.recurrent_dropout
+ 'bias_constraint':
+ constraints.serialize(self.bias_constraint),
+ 'dropout':
+ self.dropout,
+ 'recurrent_dropout':
+ self.recurrent_dropout,
+ 'implementation':
+ self.implementation
}
base_config = super(LSTM, self).get_config()
+ del base_config['cell']
+ return dict(list(base_config.items()) + list(config.items()))
+
+ @classmethod
+ def from_config(cls, config):
+ if 'implementation' in config and config['implementation'] == 0:
+ config['implementation'] = 1
+ return cls(**config)
+
+
+class Recurrent(Layer):
+ """Deprecated abstract base class for recurrent layers.
+
+ It still exists because it is leveraged by the convolutional-recurrent layers.
+ It will be removed entirely in the future.
+ It was never part of the public API.
+ Do not use.
+
+ Arguments:
+ weights: list of Numpy arrays to set as initial weights.
+ The list should have 3 elements, of shapes:
+ `[(input_dim, output_dim), (output_dim, output_dim), (output_dim,)]`.
+ return_sequences: Boolean. Whether to return the last output
+ in the output sequence, or the full sequence.
+ return_state: Boolean. Whether to return the last state
+ in addition to the output.
+ go_backwards: Boolean (default False).
+ If True, process the input sequence backwards and return the
+ reversed sequence.
+ stateful: Boolean (default False). If True, the last state
+ for each sample at index i in a batch will be used as initial
+ state for the sample of index i in the following batch.
+ unroll: Boolean (default False).
+ If True, the network will be unrolled,
+ else a symbolic loop will be used.
+ Unrolling can speed-up a RNN,
+ although it tends to be more memory-intensive.
+ Unrolling is only suitable for short sequences.
+ implementation: one of {0, 1, or 2}.
+ If set to 0, the RNN will use
+ an implementation that uses fewer, larger matrix products,
+ thus running faster on CPU but consuming more memory.
+ If set to 1, the RNN will use more matrix products,
+ but smaller ones, thus running slower
+ (may actually be faster on GPU) while consuming less memory.
+ If set to 2 (LSTM/GRU only),
+ the RNN will combine the input gate,
+ the forget gate and the output gate into a single matrix,
+ enabling more time-efficient parallelization on the GPU.
+ Note: RNN dropout must be shared for all gates,
+ resulting in a slightly reduced regularization.
+ input_dim: dimensionality of the input (integer).
+ This argument (or alternatively, the keyword argument `input_shape`)
+ is required when using this layer as the first layer in a model.
+ input_length: Length of input sequences, to be specified
+ when it is constant.
+ This argument is required if you are going to connect
+ `Flatten` then `Dense` layers upstream
+ (without it, the shape of the dense outputs cannot be computed).
+ Note that if the recurrent layer is not the first layer
+ in your model, you would need to specify the input length
+ at the level of the first layer
+ (e.g. via the `input_shape` argument)
+
+ Input shape:
+ 3D tensor with shape `(batch_size, timesteps, input_dim)`,
+ (Optional) 2D tensors with shape `(batch_size, output_dim)`.
+
+ Output shape:
+ - if `return_state`: a list of tensors. The first tensor is
+ the output. The remaining tensors are the last states,
+ each with shape `(batch_size, units)`.
+ - if `return_sequences`: 3D tensor with shape
+ `(batch_size, timesteps, units)`.
+ - else, 2D tensor with shape `(batch_size, units)`.
+
+ # Masking
+ This layer supports masking for input data with a variable number
+ of timesteps. To introduce masks to your data,
+ use an `Embedding` layer with the `mask_zero` parameter
+ set to `True`.
+
+ # Note on using statefulness in RNNs
+ You can set RNN layers to be 'stateful', which means that the states
+ computed for the samples in one batch will be reused as initial states
+ for the samples in the next batch. This assumes a one-to-one mapping
+ between samples in different successive batches.
+
+ To enable statefulness:
+ - specify `stateful=True` in the layer constructor.
+ - specify a fixed batch size for your model, by passing
+ if sequential model:
+ `batch_input_shape=(...)` to the first layer in your model.
+ else for functional model with 1 or more Input layers:
+ `batch_shape=(...)` to all the first layers in your model.
+ This is the expected shape of your inputs
+ *including the batch size*.
+ It should be a tuple of integers, e.g. `(32, 10, 100)`.
+ - specify `shuffle=False` when calling fit().
+
+ To reset the states of your model, call `.reset_states()` on either
+ a specific layer, or on your entire model.
+
+ # Note on specifying the initial state of RNNs
+ You can specify the initial state of RNN layers symbolically by
+ calling them with the keyword argument `initial_state`. The value of
+ `initial_state` should be a tensor or list of tensors representing
+ the initial state of the RNN layer.
+
+ You can specify the initial state of RNN layers numerically by
+ calling `reset_states` with the keyword argument `states`. The value of
+ `states` should be a numpy array or list of numpy arrays representing
+ the initial state of the RNN layer.
+ """
+
+ def __init__(self,
+ return_sequences=False,
+ return_state=False,
+ go_backwards=False,
+ stateful=False,
+ unroll=False,
+ implementation=0,
+ **kwargs):
+ super(Recurrent, self).__init__(**kwargs)
+ self.return_sequences = return_sequences
+ self.return_state = return_state
+ self.go_backwards = go_backwards
+ self.stateful = stateful
+ self.unroll = unroll
+ self.implementation = implementation
+ self.supports_masking = True
+ self.input_spec = [InputSpec(ndim=3)]
+ self.state_spec = None
+ self.dropout = 0
+ self.recurrent_dropout = 0
+
+ def _compute_output_shape(self, input_shape):
+ if isinstance(input_shape, list):
+ input_shape = input_shape[0]
+ input_shape = tensor_shape.TensorShape(input_shape).as_list()
+ if self.return_sequences:
+ output_shape = (input_shape[0], input_shape[1], self.units)
+ else:
+ output_shape = (input_shape[0], self.units)
+
+ if self.return_state:
+ state_shape = [tensor_shape.TensorShape(
+ (input_shape[0], self.units)) for _ in self.states]
+ return [tensor_shape.TensorShape(output_shape)] + state_shape
+ return tensor_shape.TensorShape(output_shape)
+
+ def compute_mask(self, inputs, mask):
+ if isinstance(mask, list):
+ mask = mask[0]
+ output_mask = mask if self.return_sequences else None
+ if self.return_state:
+ state_mask = [None for _ in self.states]
+ return [output_mask] + state_mask
+ return output_mask
+
+ def step(self, inputs, states):
+ raise NotImplementedError
+
+ def get_constants(self, inputs, training=None):
+ return []
+
+ def get_initial_state(self, inputs):
+ # build an all-zero tensor of shape (samples, output_dim)
+ initial_state = K.zeros_like(inputs) # (samples, timesteps, input_dim)
+ initial_state = K.sum(initial_state, axis=(1, 2)) # (samples,)
+ initial_state = K.expand_dims(initial_state) # (samples, 1)
+ initial_state = K.tile(initial_state, [1,
+ self.units]) # (samples, output_dim)
+ initial_state = [initial_state for _ in range(len(self.states))]
+ return initial_state
+
+ def preprocess_input(self, inputs, training=None):
+ return inputs
+
+ def __call__(self, inputs, initial_state=None, **kwargs):
+ if (isinstance(inputs, (list, tuple)) and
+ len(inputs) > 1
+ and initial_state is None):
+ initial_state = inputs[1:]
+ inputs = inputs[0]
+
+ # If `initial_state` is specified,
+ # and if it a Keras tensor,
+ # then add it to the inputs and temporarily
+ # modify the input spec to include the state.
+ if initial_state is None:
+ return super(Recurrent, self).__call__(inputs, **kwargs)
+
+ if not isinstance(initial_state, (list, tuple)):
+ initial_state = [initial_state]
+
+ is_keras_tensor = hasattr(initial_state[0], '_keras_history')
+ for tensor in initial_state:
+ if hasattr(tensor, '_keras_history') != is_keras_tensor:
+ raise ValueError('The initial state of an RNN layer cannot be'
+ ' specified with a mix of Keras tensors and'
+ ' non-Keras tensors')
+
+ if is_keras_tensor:
+ # Compute the full input spec, including state
+ input_spec = self.input_spec
+ state_spec = self.state_spec
+ if not isinstance(input_spec, list):
+ input_spec = [input_spec]
+ if not isinstance(state_spec, list):
+ state_spec = [state_spec]
+ self.input_spec = input_spec + state_spec
+
+ # Compute the full inputs, including state
+ inputs = [inputs] + list(initial_state)
+
+ # Perform the call
+ output = super(Recurrent, self).__call__(inputs, **kwargs)
+
+ # Restore original input spec
+ self.input_spec = input_spec
+ return output
+ else:
+ kwargs['initial_state'] = initial_state
+ return super(Recurrent, self).__call__(inputs, **kwargs)
+
+ def call(self, inputs, mask=None, training=None, initial_state=None):
+ # input shape: `(samples, time (padded with zeros), input_dim)`
+ # note that the .build() method of subclasses MUST define
+ # self.input_spec and self.state_spec with complete input shapes.
+ if isinstance(inputs, list):
+ initial_state = inputs[1:]
+ inputs = inputs[0]
+ elif initial_state is not None:
+ pass
+ elif self.stateful:
+ initial_state = self.states
+ else:
+ initial_state = self.get_initial_state(inputs)
+
+ if isinstance(mask, list):
+ mask = mask[0]
+
+ if len(initial_state) != len(self.states):
+ raise ValueError('Layer has ' + str(len(self.states)) +
+ ' states but was passed ' + str(len(initial_state)) +
+ ' initial states.')
+ input_shape = K.int_shape(inputs)
+ if self.unroll and input_shape[1] is None:
+ raise ValueError('Cannot unroll a RNN if the '
+ 'time dimension is undefined. \n'
+ '- If using a Sequential model, '
+ 'specify the time dimension by passing '
+ 'an `input_shape` or `batch_input_shape` '
+ 'argument to your first layer. If your '
+ 'first layer is an Embedding, you can '
+ 'also use the `input_length` argument.\n'
+ '- If using the functional API, specify '
+ 'the time dimension by passing a `shape` '
+ 'or `batch_shape` argument to your Input layer.')
+ constants = self.get_constants(inputs, training=None)
+ preprocessed_input = self.preprocess_input(inputs, training=None)
+ last_output, outputs, states = K.rnn(
+ self.step,
+ preprocessed_input,
+ initial_state,
+ go_backwards=self.go_backwards,
+ mask=mask,
+ constants=constants,
+ unroll=self.unroll)
+ if self.stateful:
+ updates = []
+ for i in range(len(states)):
+ updates.append((self.states[i], states[i]))
+ self.add_update(updates, inputs)
+
+ # Properly set learning phase
+ if 0 < self.dropout + self.recurrent_dropout:
+ last_output._uses_learning_phase = True
+ outputs._uses_learning_phase = True
+
+ if not self.return_sequences:
+ outputs = last_output
+
+ if self.return_state:
+ if not isinstance(states, (list, tuple)):
+ states = [states]
+ else:
+ states = list(states)
+ return [outputs] + states
+ return outputs
+
+ def reset_states(self, states=None):
+ if not self.stateful:
+ raise AttributeError('Layer must be stateful.')
+ batch_size = self.input_spec[0].shape[0]
+ if not batch_size:
+ raise ValueError('If a RNN is stateful, it needs to know '
+ 'its batch size. Specify the batch size '
+ 'of your input tensors: \n'
+ '- If using a Sequential model, '
+ 'specify the batch size by passing '
+ 'a `batch_input_shape` '
+ 'argument to your first layer.\n'
+ '- If using the functional API, specify '
+ 'the time dimension by passing a '
+ '`batch_shape` argument to your Input layer.')
+ # initialize state if None
+ if self.states[0] is None:
+ self.states = [K.zeros((batch_size, self.units)) for _ in self.states]
+ elif states is None:
+ for state in self.states:
+ K.set_value(state, np.zeros((batch_size, self.units)))
+ else:
+ if not isinstance(states, (list, tuple)):
+ states = [states]
+ if len(states) != len(self.states):
+ raise ValueError('Layer ' + self.name + ' expects ' +
+ str(len(self.states)) + ' states, '
+ 'but it received ' + str(len(states)) +
+ ' state values. Input received: ' + str(states))
+ for index, (value, state) in enumerate(zip(states, self.states)):
+ if value.shape != (batch_size, self.units):
+ raise ValueError('State ' + str(index) +
+ ' is incompatible with layer ' + self.name +
+ ': expected shape=' + str((batch_size, self.units)) +
+ ', found shape=' + str(value.shape))
+ K.set_value(state, value)
+
+ def get_config(self):
+ config = {
+ 'return_sequences': self.return_sequences,
+ 'return_state': self.return_state,
+ 'go_backwards': self.go_backwards,
+ 'stateful': self.stateful,
+ 'unroll': self.unroll,
+ 'implementation': self.implementation
+ }
+ base_config = super(Recurrent, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py
new file mode 100644
index 0000000000..b1f89a30bb
--- /dev/null
+++ b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py
@@ -0,0 +1,378 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for recurrent layers functionality other than GRU, LSTM, SimpleRNN.
+
+See also: lstm_test.py, gru_test.py, simplernn_test.py.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.platform import test
+
+
+class RNNTest(test.TestCase):
+
+ def test_minimal_rnn_cell_non_layer(self):
+
+ class MinimalRNNCell(object):
+
+ def __init__(self, units, input_dim):
+ self.units = units
+ self.state_size = units
+ self.kernel = keras.backend.variable(
+ np.random.random((input_dim, units)))
+
+ def call(self, inputs, states):
+ prev_output = states[0]
+ output = keras.backend.dot(inputs, self.kernel) + prev_output
+ return output, [output]
+
+ with self.test_session():
+ # Basic test case.
+ cell = MinimalRNNCell(32, 5)
+ x = keras.Input((None, 5))
+ layer = keras.layers.RNN(cell)
+ y = layer(x)
+ model = keras.models.Model(x, y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
+
+ # Test stacking.
+ cells = [MinimalRNNCell(8, 5),
+ MinimalRNNCell(32, 8),
+ MinimalRNNCell(32, 32)]
+ layer = keras.layers.RNN(cells)
+ y = layer(x)
+ model = keras.models.Model(x, y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
+
+ def test_minimal_rnn_cell_non_layer_multiple_states(self):
+
+ class MinimalRNNCell(object):
+
+ def __init__(self, units, input_dim):
+ self.units = units
+ self.state_size = (units, units)
+ self.kernel = keras.backend.variable(
+ np.random.random((input_dim, units)))
+
+ def call(self, inputs, states):
+ prev_output_1 = states[0]
+ prev_output_2 = states[1]
+ output = keras.backend.dot(inputs, self.kernel)
+ output += prev_output_1
+ output -= prev_output_2
+ return output, [output * 2, output * 3]
+
+ with self.test_session():
+ # Basic test case.
+ cell = MinimalRNNCell(32, 5)
+ x = keras.Input((None, 5))
+ layer = keras.layers.RNN(cell)
+ y = layer(x)
+ model = keras.models.Model(x, y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
+
+ # Test stacking.
+ cells = [MinimalRNNCell(8, 5),
+ MinimalRNNCell(16, 8),
+ MinimalRNNCell(32, 16)]
+ layer = keras.layers.RNN(cells)
+ assert layer.cell.state_size == (32, 32, 16, 16, 8, 8)
+ y = layer(x)
+ model = keras.models.Model(x, y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
+
+ def test_minimal_rnn_cell_layer(self):
+
+ class MinimalRNNCell(keras.layers.Layer):
+
+ def __init__(self, units, **kwargs):
+ self.units = units
+ self.state_size = units
+ super(MinimalRNNCell, self).__init__(**kwargs)
+
+ def build(self, input_shape):
+ self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
+ initializer='uniform',
+ name='kernel')
+ self.recurrent_kernel = self.add_weight(
+ shape=(self.units, self.units),
+ initializer='uniform',
+ name='recurrent_kernel')
+ self.built = True
+
+ def call(self, inputs, states):
+ prev_output = states[0]
+ h = keras.backend.dot(inputs, self.kernel)
+ output = h + keras.backend.dot(prev_output, self.recurrent_kernel)
+ return output, [output]
+
+ def get_config(self):
+ config = {'units': self.units}
+ base_config = super(MinimalRNNCell, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+ with self.test_session():
+ # Test basic case.
+ x = keras.Input((None, 5))
+ cell = MinimalRNNCell(32)
+ layer = keras.layers.RNN(cell)
+ y = layer(x)
+ model = keras.models.Model(x, y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
+
+ # Test basic case serialization.
+ x_np = np.random.random((6, 5, 5))
+ y_np = model.predict(x_np)
+ weights = model.get_weights()
+ config = layer.get_config()
+ with keras.utils.CustomObjectScope({'MinimalRNNCell': MinimalRNNCell}):
+ layer = keras.layers.RNN.from_config(config)
+ y = layer(x)
+ model = keras.models.Model(x, y)
+ model.set_weights(weights)
+ y_np_2 = model.predict(x_np)
+ self.assertAllClose(y_np, y_np_2, atol=1e-4)
+
+ # Test stacking.
+ cells = [MinimalRNNCell(8),
+ MinimalRNNCell(12),
+ MinimalRNNCell(32)]
+ layer = keras.layers.RNN(cells)
+ y = layer(x)
+ model = keras.models.Model(x, y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
+
+ # Test stacked RNN serialization.
+ x_np = np.random.random((6, 5, 5))
+ y_np = model.predict(x_np)
+ weights = model.get_weights()
+ config = layer.get_config()
+ with keras.utils.CustomObjectScope({'MinimalRNNCell': MinimalRNNCell}):
+ layer = keras.layers.RNN.from_config(config)
+ y = layer(x)
+ model = keras.models.Model(x, y)
+ model.set_weights(weights)
+ y_np_2 = model.predict(x_np)
+ self.assertAllClose(y_np, y_np_2, atol=1e-4)
+
+ def test_rnn_cell_with_constants_layer(self):
+
+ class RNNCellWithConstants(keras.layers.Layer):
+
+ def __init__(self, units, **kwargs):
+ self.units = units
+ self.state_size = units
+ super(RNNCellWithConstants, self).__init__(**kwargs)
+
+ def build(self, input_shape):
+ if not isinstance(input_shape, list):
+ raise TypeError('expects constants shape')
+ [input_shape, constant_shape] = input_shape
+ # will (and should) raise if more than one constant passed
+
+ self.input_kernel = self.add_weight(
+ shape=(input_shape[-1], self.units),
+ initializer='uniform',
+ name='kernel')
+ self.recurrent_kernel = self.add_weight(
+ shape=(self.units, self.units),
+ initializer='uniform',
+ name='recurrent_kernel')
+ self.constant_kernel = self.add_weight(
+ shape=(constant_shape[-1], self.units),
+ initializer='uniform',
+ name='constant_kernel')
+ self.built = True
+
+ def call(self, inputs, states, constants):
+ [prev_output] = states
+ [constant] = constants
+ h_input = keras.backend.dot(inputs, self.input_kernel)
+ h_state = keras.backend.dot(prev_output, self.recurrent_kernel)
+ h_const = keras.backend.dot(constant, self.constant_kernel)
+ output = h_input + h_state + h_const
+ return output, [output]
+
+ def get_config(self):
+ config = {'units': self.units}
+ base_config = super(RNNCellWithConstants, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+ with self.test_session():
+ # Test basic case.
+ x = keras.Input((None, 5))
+ c = keras.Input((3,))
+ cell = RNNCellWithConstants(32)
+ layer = keras.layers.RNN(cell)
+ y = layer(x, constants=c)
+ model = keras.models.Model([x, c], y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(
+ [np.zeros((6, 5, 5)), np.zeros((6, 3))],
+ np.zeros((6, 32))
+ )
+
+ with self.test_session():
+ # Test basic case serialization.
+ x_np = np.random.random((6, 5, 5))
+ c_np = np.random.random((6, 3))
+ y_np = model.predict([x_np, c_np])
+ weights = model.get_weights()
+ config = layer.get_config()
+ custom_objects = {'RNNCellWithConstants': RNNCellWithConstants}
+ with keras.utils.CustomObjectScope(custom_objects):
+ layer = keras.layers.RNN.from_config(config.copy())
+ y = layer(x, constants=c)
+ model = keras.models.Model([x, c], y)
+ model.set_weights(weights)
+ y_np_2 = model.predict([x_np, c_np])
+ self.assertAllClose(y_np, y_np_2, atol=1e-4)
+
+ with self.test_session():
+ # test flat list inputs
+ with keras.utils.CustomObjectScope(custom_objects):
+ layer = keras.layers.RNN.from_config(config.copy())
+ y = layer([x, c])
+ model = keras.models.Model([x, c], y)
+ model.set_weights(weights)
+ y_np_3 = model.predict([x_np, c_np])
+ self.assertAllClose(y_np, y_np_3, atol=1e-4)
+
+ def test_rnn_cell_with_constants_layer_passing_initial_state(self):
+
+ class RNNCellWithConstants(keras.layers.Layer):
+
+ def __init__(self, units, **kwargs):
+ self.units = units
+ self.state_size = units
+ super(RNNCellWithConstants, self).__init__(**kwargs)
+
+ def build(self, input_shape):
+ if not isinstance(input_shape, list):
+ raise TypeError('expects constants shape')
+ [input_shape, constant_shape] = input_shape
+ # will (and should) raise if more than one constant passed
+
+ self.input_kernel = self.add_weight(
+ shape=(input_shape[-1], self.units),
+ initializer='uniform',
+ name='kernel')
+ self.recurrent_kernel = self.add_weight(
+ shape=(self.units, self.units),
+ initializer='uniform',
+ name='recurrent_kernel')
+ self.constant_kernel = self.add_weight(
+ shape=(constant_shape[-1], self.units),
+ initializer='uniform',
+ name='constant_kernel')
+ self.built = True
+
+ def call(self, inputs, states, constants):
+ [prev_output] = states
+ [constant] = constants
+ h_input = keras.backend.dot(inputs, self.input_kernel)
+ h_state = keras.backend.dot(prev_output, self.recurrent_kernel)
+ h_const = keras.backend.dot(constant, self.constant_kernel)
+ output = h_input + h_state + h_const
+ return output, [output]
+
+ def get_config(self):
+ config = {'units': self.units}
+ base_config = super(RNNCellWithConstants, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+ with self.test_session():
+ # Test basic case.
+ x = keras.Input((None, 5))
+ c = keras.Input((3,))
+ s = keras.Input((32,))
+ cell = RNNCellWithConstants(32)
+ layer = keras.layers.RNN(cell)
+ y = layer(x, initial_state=s, constants=c)
+ model = keras.models.Model([x, s, c], y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(
+ [np.zeros((6, 5, 5)), np.zeros((6, 32)), np.zeros((6, 3))],
+ np.zeros((6, 32))
+ )
+
+ with self.test_session():
+ # Test basic case serialization.
+ x_np = np.random.random((6, 5, 5))
+ s_np = np.random.random((6, 32))
+ c_np = np.random.random((6, 3))
+ y_np = model.predict([x_np, s_np, c_np])
+ weights = model.get_weights()
+ config = layer.get_config()
+ custom_objects = {'RNNCellWithConstants': RNNCellWithConstants}
+ with keras.utils.CustomObjectScope(custom_objects):
+ layer = keras.layers.RNN.from_config(config.copy())
+ y = layer(x, initial_state=s, constants=c)
+ model = keras.models.Model([x, s, c], y)
+ model.set_weights(weights)
+ y_np_2 = model.predict([x_np, s_np, c_np])
+ self.assertAllClose(y_np, y_np_2, atol=1e-4)
+
+ # verify that state is used
+ y_np_2_different_s = model.predict([x_np, s_np + 10., c_np])
+ with self.assertRaises(AssertionError):
+ self.assertAllClose(y_np, y_np_2_different_s, atol=1e-4)
+
+ with self.test_session():
+ # test flat list inputs
+ with keras.utils.CustomObjectScope(custom_objects):
+ layer = keras.layers.RNN.from_config(config.copy())
+ y = layer([x, s, c])
+ model = keras.models.Model([x, s, c], y)
+ model.set_weights(weights)
+ y_np_3 = model.predict([x_np, s_np, c_np])
+ self.assertAllClose(y_np, y_np_3, atol=1e-4)
+
+ def test_stacked_rnn_attributes(self):
+ cells = [keras.layers.LSTMCell(3),
+ keras.layers.LSTMCell(3, kernel_regularizer='l2')]
+ layer = keras.layers.RNN(cells)
+ layer.build((None, None, 5))
+
+ # Test regularization losses
+ assert len(layer.losses) == 1
+
+ # Test weights
+ assert len(layer.trainable_weights) == 6
+ cells[0].trainable = False
+ assert len(layer.trainable_weights) == 3
+ assert len(layer.non_trainable_weights) == 3
+
+ # Test `get_losses_for`
+ x = keras.Input((None, 5))
+ y = keras.backend.sum(x)
+ cells[0].add_loss(y, inputs=x)
+ assert layer.get_losses_for(x) == [y]
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py b/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py
index 9833485236..7edebdacd0 100644
--- a/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py
@@ -156,8 +156,10 @@ class SimpleRNNLayerTest(test.TestCase):
activity_regularizer='l1')
layer.build((None, None, 2))
self.assertEqual(len(layer.losses), 3)
- layer(keras.backend.variable(np.ones((2, 3, 2))))
- self.assertEqual(len(layer.losses), 4)
+
+ x = keras.backend.variable(np.ones((2, 3, 2)))
+ layer(x)
+ self.assertEqual(len(layer.get_losses_for(x)), 1)
def test_constraints_SimpleRNN(self):
embedding_dim = 4
@@ -175,9 +177,9 @@ class SimpleRNNLayerTest(test.TestCase):
recurrent_constraint=r_constraint,
bias_constraint=b_constraint)
layer.build((None, None, embedding_dim))
- self.assertEqual(layer.kernel.constraint, k_constraint)
- self.assertEqual(layer.recurrent_kernel.constraint, r_constraint)
- self.assertEqual(layer.bias.constraint, b_constraint)
+ self.assertEqual(layer.cell.kernel.constraint, k_constraint)
+ self.assertEqual(layer.cell.recurrent_kernel.constraint, r_constraint)
+ self.assertEqual(layer.cell.bias.constraint, b_constraint)
def test_with_masking_layer_SimpleRNN(self):
layer_class = keras.layers.SimpleRNN
diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py
index acf0a5e179..b94bf8f0f6 100644
--- a/tensorflow/python/keras/layers/__init__.py
+++ b/tensorflow/python/keras/layers/__init__.py
@@ -134,6 +134,11 @@ from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool2D
from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool3D
# Recurrent layers.
+from tensorflow.python.keras._impl.keras.layers.recurrent import RNN
+from tensorflow.python.keras._impl.keras.layers.recurrent import StackedRNNCells
+from tensorflow.python.keras._impl.keras.layers.recurrent import SimpleRNNCell
+from tensorflow.python.keras._impl.keras.layers.recurrent import GRUCell
+from tensorflow.python.keras._impl.keras.layers.recurrent import LSTMCell
from tensorflow.python.keras._impl.keras.layers.recurrent import SimpleRNN
from tensorflow.python.keras._impl.keras.layers.recurrent import GRU
from tensorflow.python.keras._impl.keras.layers.recurrent import LSTM
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 7fa504e85e..8d6f863a4c 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1186,6 +1186,7 @@ cuda_py_test(
srcs = ["check_ops_test.py"],
additional_deps = [
"//third_party/py/numpy",
+ "//tensorflow/python/eager:context",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 76b80e60ea..612f2c0a72 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -505,7 +505,7 @@ class StridedSliceTest(test_util.TensorFlowTestCase):
_ = checker2[...]
_ = checker2[tuple()]
- def testInt64GPU(self):
+ def testFloatSlicedArrayAndInt64IndicesGPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
with self.test_session(use_gpu=True, force_gpu=True):
@@ -516,6 +516,17 @@ class StridedSliceTest(test_util.TensorFlowTestCase):
s = array_ops.strided_slice(x, begin, end, strides)
self.assertAllEqual([3.], self.evaluate(s))
+ def testInt64SlicedArrayAndIndicesGPU(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+ with self.test_session(use_gpu=True, force_gpu=True):
+ x = constant_op.constant([1, 2, 3], dtype=dtypes.int64)
+ begin = constant_op.constant([2], dtype=dtypes.int64)
+ end = constant_op.constant([3], dtype=dtypes.int64)
+ strides = constant_op.constant([1], dtype=dtypes.int64)
+ s = array_ops.strided_slice(x, begin, end, strides)
+ self.assertAllEqual([3], self.evaluate(s))
+
def testDegenerateSlices(self):
with self.test_session(use_gpu=True):
checker = StridedSliceChecker(self, StridedSliceChecker.REF_TENSOR)
diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py
index ed859e3774..43785adcee 100644
--- a/tensorflow/python/kernel_tests/check_ops_test.py
+++ b/tensorflow/python/kernel_tests/check_ops_test.py
@@ -20,10 +20,13 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.platform import test
@@ -71,110 +74,178 @@ class AssertProperIterableTest(test.TestCase):
class AssertEqualTest(test.TestCase):
+ @test_util.run_in_graph_and_eager_modes()
def test_doesnt_raise_when_equal(self):
- with self.test_session():
+ small = constant_op.constant([1, 2], name="small")
+ with ops.control_dependencies([check_ops.assert_equal(small, small)]):
+ out = array_ops.identity(small)
+ self.evaluate(out)
+
+ def test_returns_none_with_eager(self):
+ with context.eager_mode():
small = constant_op.constant([1, 2], name="small")
- with ops.control_dependencies([check_ops.assert_equal(small, small)]):
- out = array_ops.identity(small)
- out.eval()
+ x = check_ops.assert_equal(small, small)
+ assert x is None
+ @test_util.run_in_graph_and_eager_modes()
def test_raises_when_greater(self):
- with self.test_session():
- # Static check
- static_small = constant_op.constant([1, 2], name="small")
- static_big = constant_op.constant([3, 4], name="big")
- with self.assertRaisesRegexp(ValueError, "fail"):
- check_ops.assert_equal(static_big, static_small, message="fail")
- # Dynamic check
- small = array_ops.placeholder(dtypes.int32, name="small")
- big = array_ops.placeholder(dtypes.int32, name="big")
- with ops.control_dependencies(
- [check_ops.assert_equal(
- big, small, message="fail")]):
- out = array_ops.identity(small)
- with self.assertRaisesOpError("fail.*big.*small"):
- out.eval(feed_dict={small: [1, 2], big: [3, 4]})
-
+ # Static check
+ static_small = constant_op.constant([1, 2], name="small")
+ static_big = constant_op.constant([3, 4], name="big")
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"):
+ check_ops.assert_equal(static_big, static_small, message="fail")
+
+ # Dynamic check
+ if context.in_graph_mode():
+ with self.test_session():
+ small = array_ops.placeholder(dtypes.int32, name="small")
+ big = array_ops.placeholder(dtypes.int32, name="big")
+ with ops.control_dependencies(
+ [check_ops.assert_equal(
+ big, small, message="fail")]):
+ out = array_ops.identity(small)
+ with self.assertRaisesOpError("fail.*big.*small"):
+ out.eval(feed_dict={small: [1, 2], big: [3, 4]})
+
+ def test_error_message_eager(self):
+ expected_error_msg_full = r"""big does not equal small
+Condition x == y did not hold.
+Indices of first 6 different values:
+\[\[0 0\]
+ \[1 1\]
+ \[2 0\]\]
+Corresponding x values:
+\[2 3 6\]
+Corresponding y values:
+\[20 30 60\]
+First 6 elements of x:
+\[2 2 3 3 6 6\]
+First 6 elements of y:
+\[20 2 3 30 60 6\]
+"""
+ expected_error_msg_short = r"""big does not equal small
+Condition x == y did not hold.
+Indices of first 2 different values:
+\[\[0 0\]
+ \[1 1\]\]
+Corresponding x values:
+\[2 3\]
+Corresponding y values:
+\[20 30\]
+First 2 elements of x:
+\[2 2\]
+First 2 elements of y:
+\[20 2\]
+"""
+ with context.eager_mode():
+ big = constant_op.constant([[2, 2], [3, 3], [6, 6]])
+ small = constant_op.constant([[20, 2], [3, 30], [60, 6]])
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ expected_error_msg_full):
+ check_ops.assert_equal(big, small, message="big does not equal small",
+ summarize=10)
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ expected_error_msg_short):
+ check_ops.assert_equal(big, small, message="big does not equal small",
+ summarize=2)
+
+ @test_util.run_in_graph_and_eager_modes()
def test_raises_when_less(self):
- with self.test_session():
- # Static check
- static_small = constant_op.constant([3, 1], name="small")
- static_big = constant_op.constant([4, 2], name="big")
- with self.assertRaisesRegexp(ValueError, "fail"):
- check_ops.assert_equal(static_big, static_small, message="fail")
- # Dynamic check
- small = array_ops.placeholder(dtypes.int32, name="small")
- big = array_ops.placeholder(dtypes.int32, name="big")
- with ops.control_dependencies([check_ops.assert_equal(small, big)]):
- out = array_ops.identity(small)
- with self.assertRaisesOpError("small.*big"):
- out.eval(feed_dict={small: [3, 1], big: [4, 2]})
+ # Static check
+ static_small = constant_op.constant([3, 1], name="small")
+ static_big = constant_op.constant([4, 2], name="big")
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"):
+ check_ops.assert_equal(static_big, static_small, message="fail")
+
+ # Dynamic check
+ if context.in_graph_mode():
+ with self.test_session():
+ small = array_ops.placeholder(dtypes.int32, name="small")
+ big = array_ops.placeholder(dtypes.int32, name="big")
+ with ops.control_dependencies([check_ops.assert_equal(small, big)]):
+ out = array_ops.identity(small)
+ with self.assertRaisesOpError("small.*big"):
+ out.eval(feed_dict={small: [3, 1], big: [4, 2]})
+ @test_util.run_in_graph_and_eager_modes()
def test_doesnt_raise_when_equal_and_broadcastable_shapes(self):
- with self.test_session():
- small = constant_op.constant([1, 2], name="small")
- small_2 = constant_op.constant([1, 2], name="small_2")
- with ops.control_dependencies([check_ops.assert_equal(small, small_2)]):
- out = array_ops.identity(small)
- out.eval()
+ small = constant_op.constant([[1, 2], [1, 2]], name="small")
+ small_2 = constant_op.constant([1, 2], name="small_2")
+ with ops.control_dependencies([check_ops.assert_equal(small, small_2)]):
+ out = array_ops.identity(small)
+ self.evaluate(out)
+ @test_util.run_in_graph_and_eager_modes()
def test_raises_when_equal_but_non_broadcastable_shapes(self):
- with self.test_session():
- small = constant_op.constant([1, 1, 1], name="small")
- small_2 = constant_op.constant([1, 1], name="small_2")
- with self.assertRaisesRegexp(ValueError, "must be"):
- with ops.control_dependencies([check_ops.assert_equal(small, small_2)]):
- out = array_ops.identity(small)
- out.eval()
+ small = constant_op.constant([1, 1, 1], name="small")
+ small_2 = constant_op.constant([1, 1], name="small_2")
+ # The exception in eager and non-eager mode is different because
+ # eager mode relies on shape check done as part of the C++ op, while
+ # graph mode does shape checks when creating the `Operation` instance.
+ with self.assertRaisesRegexp(
+ (errors.InvalidArgumentError, ValueError),
+ (r"Incompatible shapes: \[3\] vs. \[2\]|"
+ r"Dimensions must be equal, but are 3 and 2")):
+ with ops.control_dependencies([check_ops.assert_equal(small, small_2)]):
+ out = array_ops.identity(small)
+ self.evaluate(out)
+ @test_util.run_in_graph_and_eager_modes()
def test_doesnt_raise_when_both_empty(self):
- with self.test_session():
- larry = constant_op.constant([])
- curly = constant_op.constant([])
- with ops.control_dependencies([check_ops.assert_equal(larry, curly)]):
- out = array_ops.identity(larry)
- out.eval()
+ larry = constant_op.constant([])
+ curly = constant_op.constant([])
+ with ops.control_dependencies([check_ops.assert_equal(larry, curly)]):
+ out = array_ops.identity(larry)
+ self.evaluate(out)
class AssertNoneEqualTest(test.TestCase):
+ @test_util.run_in_graph_and_eager_modes()
def test_doesnt_raise_when_not_equal(self):
- with self.test_session():
- small = constant_op.constant([1, 2], name="small")
- big = constant_op.constant([10, 20], name="small")
- with ops.control_dependencies(
- [check_ops.assert_none_equal(big, small)]):
- out = array_ops.identity(small)
- out.eval()
-
+ small = constant_op.constant([1, 2], name="small")
+ big = constant_op.constant([10, 20], name="small")
+ with ops.control_dependencies(
+ [check_ops.assert_none_equal(big, small)]):
+ out = array_ops.identity(small)
+ self.evaluate(out)
+
+ @test_util.run_in_graph_and_eager_modes()
def test_raises_when_equal(self):
- with self.test_session():
- small = constant_op.constant([3, 1], name="small")
+ small = constant_op.constant([3, 1], name="small")
+ with self.assertRaisesOpError("x != y did not hold"):
with ops.control_dependencies(
[check_ops.assert_none_equal(small, small)]):
out = array_ops.identity(small)
- with self.assertRaisesOpError("x != y did not hold"):
- out.eval()
+ self.evaluate(out)
+ @test_util.run_in_graph_and_eager_modes()
def test_doesnt_raise_when_not_equal_and_broadcastable_shapes(self):
- with self.test_session():
- small = constant_op.constant([1, 2], name="small")
- big = constant_op.constant([3], name="big")
- with ops.control_dependencies(
- [check_ops.assert_none_equal(small, big)]):
- out = array_ops.identity(small)
- out.eval()
-
+ small = constant_op.constant([1, 2], name="small")
+ big = constant_op.constant([3], name="big")
+ with ops.control_dependencies(
+ [check_ops.assert_none_equal(small, big)]):
+ out = array_ops.identity(small)
+ self.evaluate(out)
+
+ @test_util.run_in_graph_and_eager_modes()
def test_raises_when_not_equal_but_non_broadcastable_shapes(self):
with self.test_session():
small = constant_op.constant([1, 1, 1], name="small")
big = constant_op.constant([10, 10], name="big")
- with self.assertRaisesRegexp(ValueError, "must be"):
+ # The exception in eager and non-eager mode is different because
+ # eager mode relies on shape check done as part of the C++ op, while
+ # graph mode does shape checks when creating the `Operation` instance.
+ with self.assertRaisesRegexp(
+ (ValueError, errors.InvalidArgumentError),
+ (r"Incompatible shapes: \[3\] vs. \[2\]|"
+ r"Dimensions must be equal, but are 3 and 2")):
with ops.control_dependencies(
[check_ops.assert_none_equal(small, big)]):
out = array_ops.identity(small)
- out.eval()
+ self.evaluate(out)
+ @test_util.run_in_graph_and_eager_modes()
def test_doesnt_raise_when_both_empty(self):
with self.test_session():
larry = constant_op.constant([])
@@ -182,62 +253,82 @@ class AssertNoneEqualTest(test.TestCase):
with ops.control_dependencies(
[check_ops.assert_none_equal(larry, curly)]):
out = array_ops.identity(larry)
- out.eval()
+ self.evaluate(out)
+
+ def test_returns_none_with_eager(self):
+ with context.eager_mode():
+ t1 = constant_op.constant([1, 2])
+ t2 = constant_op.constant([3, 4])
+ x = check_ops.assert_none_equal(t1, t2)
+ assert x is None
class AssertLessTest(test.TestCase):
+ @test_util.run_in_graph_and_eager_modes()
def test_raises_when_equal(self):
- with self.test_session():
- small = constant_op.constant([1, 2], name="small")
+ small = constant_op.constant([1, 2], name="small")
+ with self.assertRaisesOpError("failure message.*\n*.* x < y did not hold"):
with ops.control_dependencies(
[check_ops.assert_less(
- small, small, message="fail")]):
+ small, small, message="failure message")]):
out = array_ops.identity(small)
- with self.assertRaisesOpError("fail.*small.*small"):
- out.eval()
+ self.evaluate(out)
+ @test_util.run_in_graph_and_eager_modes()
def test_raises_when_greater(self):
- with self.test_session():
- small = constant_op.constant([1, 2], name="small")
- big = constant_op.constant([3, 4], name="big")
+ small = constant_op.constant([1, 2], name="small")
+ big = constant_op.constant([3, 4], name="big")
+ with self.assertRaisesOpError("x < y did not hold"):
with ops.control_dependencies([check_ops.assert_less(big, small)]):
out = array_ops.identity(small)
- with self.assertRaisesOpError("big.*small"):
- out.eval()
+ self.evaluate(out)
+ @test_util.run_in_graph_and_eager_modes()
def test_doesnt_raise_when_less(self):
- with self.test_session():
- small = constant_op.constant([3, 1], name="small")
- big = constant_op.constant([4, 2], name="big")
- with ops.control_dependencies([check_ops.assert_less(small, big)]):
- out = array_ops.identity(small)
- out.eval()
+ small = constant_op.constant([3, 1], name="small")
+ big = constant_op.constant([4, 2], name="big")
+ with ops.control_dependencies([check_ops.assert_less(small, big)]):
+ out = array_ops.identity(small)
+ self.evaluate(out)
+ @test_util.run_in_graph_and_eager_modes()
def test_doesnt_raise_when_less_and_broadcastable_shapes(self):
- with self.test_session():
- small = constant_op.constant([1], name="small")
- big = constant_op.constant([3, 2], name="big")
- with ops.control_dependencies([check_ops.assert_less(small, big)]):
- out = array_ops.identity(small)
- out.eval()
+ small = constant_op.constant([1], name="small")
+ big = constant_op.constant([3, 2], name="big")
+ with ops.control_dependencies([check_ops.assert_less(small, big)]):
+ out = array_ops.identity(small)
+ self.evaluate(out)
+ @test_util.run_in_graph_and_eager_modes()
def test_raises_when_less_but_non_broadcastable_shapes(self):
- with self.test_session():
- small = constant_op.constant([1, 1, 1], name="small")
- big = constant_op.constant([3, 2], name="big")
- with self.assertRaisesRegexp(ValueError, "must be"):
- with ops.control_dependencies([check_ops.assert_less(small, big)]):
- out = array_ops.identity(small)
- out.eval()
+ small = constant_op.constant([1, 1, 1], name="small")
+ big = constant_op.constant([3, 2], name="big")
+ # The exception in eager and non-eager mode is different because
+ # eager mode relies on shape check done as part of the C++ op, while
+ # graph mode does shape checks when creating the `Operation` instance.
+ with self.assertRaisesRegexp(
+ (ValueError, errors.InvalidArgumentError),
+ (r"Incompatible shapes: \[3\] vs. \[2\]|"
+ "Dimensions must be equal, but are 3 and 2")):
+ with ops.control_dependencies([check_ops.assert_less(small, big)]):
+ out = array_ops.identity(small)
+ self.evaluate(out)
+ @test_util.run_in_graph_and_eager_modes()
def test_doesnt_raise_when_both_empty(self):
- with self.test_session():
- larry = constant_op.constant([])
- curly = constant_op.constant([])
- with ops.control_dependencies([check_ops.assert_less(larry, curly)]):
- out = array_ops.identity(larry)
- out.eval()
+ larry = constant_op.constant([])
+ curly = constant_op.constant([])
+ with ops.control_dependencies([check_ops.assert_less(larry, curly)]):
+ out = array_ops.identity(larry)
+ self.evaluate(out)
+
+ def test_returns_none_with_eager(self):
+ with context.eager_mode():
+ t1 = constant_op.constant([1, 2])
+ t2 = constant_op.constant([3, 4])
+ x = check_ops.assert_less(t1, t2)
+ assert x is None
class AssertLessEqualTest(test.TestCase):
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index a21182beba..fc125daf38 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -2856,11 +2856,12 @@ class EagerTest(test.TestCase):
def testCond(self):
with context.eager_mode():
pred = math_ops.less(1, 2)
- fn1 = lambda: constant_op.constant(10)
- fn2 = lambda: constant_op.constant(20)
+ fn1 = lambda: [constant_op.constant(10)]
+ fn2 = lambda: [constant_op.constant(20)]
r = control_flow_ops.cond(pred, fn1, fn2)
self.assertAllEqual(r.numpy(), 10)
+ self.assertFalse(isinstance(r, list))
def testWhileLoop(self):
with context.eager_mode():
diff --git a/tensorflow/python/kernel_tests/distributions/multinomial_test.py b/tensorflow/python/kernel_tests/distributions/multinomial_test.py
index d62aca151a..e24e8ade73 100644
--- a/tensorflow/python/kernel_tests/distributions/multinomial_test.py
+++ b/tensorflow/python/kernel_tests/distributions/multinomial_test.py
@@ -281,10 +281,10 @@ class MultinomialTest(test.TestCase):
dist.variance(),
dist.stddev(),
])
- self.assertAllClose(sample_mean_, analytic_mean, atol=0., rtol=0.01)
- self.assertAllClose(sample_cov_, analytic_cov, atol=0., rtol=0.01)
- self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.01)
- self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.01)
+ self.assertAllClose(sample_mean_, analytic_mean, atol=0.01, rtol=0.01)
+ self.assertAllClose(sample_cov_, analytic_cov, atol=0.01, rtol=0.01)
+ self.assertAllClose(sample_var_, analytic_var, atol=0.01, rtol=0.01)
+ self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.01, rtol=0.01)
def testSampleUnbiasedNonScalarBatch(self):
with self.test_session() as sess:
diff --git a/tensorflow/python/kernel_tests/gather_nd_op_test.py b/tensorflow/python/kernel_tests/gather_nd_op_test.py
index af5e23c926..5109ed98c9 100644
--- a/tensorflow/python/kernel_tests/gather_nd_op_test.py
+++ b/tensorflow/python/kernel_tests/gather_nd_op_test.py
@@ -25,6 +25,7 @@ import numpy as np
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import variables
@@ -185,6 +186,9 @@ class GatherNdTest(test.TestCase):
self.assertAllEqual(expected.reshape([10, 10, 20]), gather_nd_val)
self.assertEqual([10, 10, 20], gather_nd_t.get_shape())
+ def assertIndexedSlices(self, t):
+ self.assertIsInstance(t, ops.IndexedSlices)
+
def testUnknownIndices(self):
params = constant_op.constant([[0, 1, 2]])
indices = array_ops.placeholder(dtypes.int32)
@@ -233,7 +237,8 @@ class GatherNdTest(test.TestCase):
grads = gradients_impl.gradients([outputs], [inputs], [grad_vals])[0]
expected_grads = np.array([[3, 4], [1, 2]], dtype=np.float64)
with self.test_session(use_gpu=True):
- self.assertAllEqual(expected_grads, grads.eval())
+ self.assertIndexedSlices(grads)
+ self.assertAllEqual(expected_grads, ops.convert_to_tensor(grads).eval())
def testGradientsRank3Elements(self):
indices = constant_op.constant(
@@ -284,7 +289,8 @@ class GatherNdTest(test.TestCase):
[0, 0, 0, 0, 0, 0, 0, 0, 0], [3, 3, 3, 3, 3, 3, 3, 3, 3]],
dtype=np.float64)
with self.test_session(use_gpu=True):
- self.assertAllEqual(expected_grads, grads.eval())
+ self.assertIndexedSlices(grads)
+ self.assertAllEqual(expected_grads, ops.convert_to_tensor(grads).eval())
class GatherNdOpBenchmark(test.Benchmark):
diff --git a/tensorflow/python/kernel_tests/iterator_ops_test.py b/tensorflow/python/kernel_tests/iterator_ops_test.py
index 60a44b5b14..b198fa1754 100644
--- a/tensorflow/python/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/python/kernel_tests/iterator_ops_test.py
@@ -17,12 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.ops import readers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -31,10 +33,13 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import script_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
@@ -54,6 +59,15 @@ class IteratorTest(test.TestCase):
with self.assertRaisesRegexp(LookupError, "No gradient defined"):
gradients_impl.gradients(value, [component, side])
+ def testCapturingStateInOneShotRaisesException(self):
+ var = variables.Variable(37.0, name="myvar")
+ dataset = (dataset_ops.Dataset.from_tensor_slices([0.0, 1.0, 2.0])
+ .map(lambda x: x + var))
+ with self.assertRaisesRegexp(
+ ValueError, r"`Dataset.make_one_shot_iterator\(\)` does not support "
+ "datasets that capture stateful objects.+myvar"):
+ dataset.make_one_shot_iterator()
+
def testOneShotIterator(self):
components = (np.arange(7),
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
@@ -533,6 +547,64 @@ class IteratorTest(test.TestCase):
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
})
+ def testIncorrectIteratorRestore(self):
+
+ def _path():
+ return os.path.join(self.get_temp_dir(), "iterator")
+
+ def _save_op(iterator_resource):
+ iterator_state_variant = gen_dataset_ops.serialize_iterator(
+ iterator_resource)
+ save_op = io_ops.write_file(
+ _path(), parsing_ops.serialize_tensor(iterator_state_variant))
+ return save_op
+
+ def _restore_op(iterator_resource):
+ iterator_state_variant = parsing_ops.parse_tensor(
+ io_ops.read_file(_path()), dtypes.variant)
+ restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+ iterator_state_variant)
+ return restore_op
+
+ def _build_range_dataset_graph():
+ start = 1
+ stop = 10
+ iterator = dataset_ops.Dataset.range(start,
+ stop).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ save_op = _save_op(iterator._iterator_resource)
+ restore_op = _restore_op(iterator._iterator_resource)
+ return init_op, get_next, save_op, restore_op
+
+ def _build_reader_dataset_graph():
+ filenames = ["test"] # Does not exist but we don't care in this test.
+ iterator = readers.FixedLengthRecordDataset(
+ filenames, 1, 0, 0).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next_op = iterator.get_next()
+ save_op = _save_op(iterator._iterator_resource)
+ restore_op = _restore_op(iterator._iterator_resource)
+ return init_op, get_next_op, save_op, restore_op
+
+ # Saving iterator for RangeDataset graph.
+ with ops.Graph().as_default() as g:
+ init_op, _, save_op, _ = _build_range_dataset_graph()
+ with self.test_session(graph=g) as sess:
+ sess.run(init_op)
+ sess.run(save_op)
+
+ # Attempt to restore the saved iterator into an IteratorResource of
+ # incompatible type. An iterator of RangeDataset has output type int64,
+ # while an iterator of FixedLengthRecordDataset has output type string.
+ # So an InvalidArgumentError should be raised by
+ # IteratorResource::set_iterator.
+ with ops.Graph().as_default() as g:
+ _, _, _, restore_op = _build_reader_dataset_graph()
+ with self.test_session(graph=g) as sess:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(restore_op)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/range_dataset_op_test.py b/tensorflow/python/kernel_tests/range_dataset_op_test.py
index 3c1685c951..0c530522b8 100644
--- a/tensorflow/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/python/kernel_tests/range_dataset_op_test.py
@@ -17,15 +17,32 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
class RangeDatasetTest(test.TestCase):
+ def tearDown(self):
+ # Remove all checkpoint files.
+ prefix = self._iterator_checkpoint_prefix()
+ pattern = prefix + "*"
+ files = gfile.Glob(pattern)
+ map(gfile.Remove, files)
+
def testStop(self):
stop = array_ops.placeholder(dtypes.int64, shape=[])
iterator = dataset_ops.Dataset.range(stop).make_initializable_iterator()
@@ -151,6 +168,319 @@ class RangeDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ def _iterator_checkpoint_prefix(self):
+ return os.path.join(self.get_temp_dir(), "iterator")
+
+ def _save_op(self, iterator_resource):
+ iterator_state_variant = gen_dataset_ops.serialize_iterator(
+ iterator_resource)
+ save_op = io_ops.write_file(
+ self._iterator_checkpoint_prefix(),
+ parsing_ops.serialize_tensor(iterator_state_variant))
+ return save_op
+
+ def _restore_op(self, iterator_resource):
+ iterator_state_variant = parsing_ops.parse_tensor(
+ io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant)
+ restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+ iterator_state_variant)
+ return restore_op
+
+ def testSaveRestore(self):
+
+ def _build_graph(start, stop):
+ iterator = dataset_ops.Dataset.range(start,
+ stop).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
+ return init_op, get_next, save_op, restore_op
+
+ # Saving and restoring in different sessions.
+ start = 2
+ stop = 10
+ break_point = 5
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, _ = _build_graph(start, stop)
+ with self.test_session(graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
+ for i in range(start, break_point):
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next, _, restore_op = _build_graph(start, stop)
+ with self.test_session(graph=g) as sess:
+ sess.run(init_op)
+ sess.run(restore_op)
+ for i in range(break_point, stop):
+ self.assertEqual(i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Saving and restoring in same session.
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, restore_op = _build_graph(start, stop)
+ with self.test_session(graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
+ for i in range(start, break_point):
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
+ sess.run(restore_op)
+ for i in range(break_point, stop):
+ self.assertEqual(i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testRestoreWithoutBuildingDatasetGraph(self):
+
+ def _build_graph(start, stop, num_epochs):
+ dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs)
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
+ return init_op, get_next, save_op, restore_op
+
+ # Saving and restoring in different sessions.
+ start = 2
+ stop = 10
+ num_epochs = 5
+ break_point = 5
+ break_epoch = 3
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
+ for _ in range(break_epoch):
+ for i in range(start, stop):
+ self.assertEqual(i, sess.run(get_next))
+ for i in range(start, break_point):
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
+
+ with ops.Graph().as_default() as g:
+ # Create an empty IteratorResource and restore the Iterator into it.
+ output_types = dtypes.int64
+ output_shapes = tensor_shape.scalar()
+ iterator = iterator_ops.Iterator.from_structure(output_types,
+ output_shapes)
+ restore_op = self._restore_op(iterator._iterator_resource)
+ get_next = iterator.get_next()
+ with self.test_session(graph=g) as sess:
+ sess.run(restore_op)
+ for i in range(break_point, stop):
+ self.assertEqual(i, sess.run(get_next))
+ for _ in range(break_epoch + 1, num_epochs):
+ for i in range(start, stop):
+ self.assertEqual(i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testRestoreInModifiedGraph(self):
+
+ def _build_graph(start, stop):
+ dataset = dataset_ops.Dataset.range(start, stop)
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
+ return init_op, get_next, save_op, restore_op
+
+ # Saving and restoring in different sessions.
+ start = 2
+ stop = 10
+ stop_1 = 8
+ break_point = 5
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, _ = _build_graph(start, stop)
+ with self.test_session(graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
+ for i in range(start, break_point):
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
+
+ with ops.Graph().as_default() as g:
+ # Intentionally build a graph with a different value for stop to make sure
+ # the original dataset graph is actually getting loaded.
+ init_op, get_next, _, restore_op = _build_graph(start, stop_1)
+ with self.test_session(graph=g) as sess:
+ sess.run(restore_op)
+ for i in range(break_point, stop):
+ self.assertEqual(i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testInitThenRestore(self):
+ # Note: Calling init_op before restore_op is redundant. This test just makes
+ # sure we do not fail if restore is called on an already initialized
+ # iterator resource.
+
+ def _build_graph(start, stop):
+ dataset = dataset_ops.Dataset.range(start, stop)
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
+ return init_op, get_next, save_op, restore_op
+
+ # Saving and restoring in different sessions.
+ start = 2
+ stop = 10
+ break_point = 5
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, _ = _build_graph(start, stop)
+ with self.test_session(graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
+ for i in range(start, break_point):
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next, _, restore_op = _build_graph(start, stop)
+ with self.test_session(graph=g) as sess:
+ sess.run(init_op)
+ sess.run(restore_op)
+ for i in range(break_point, stop):
+ self.assertEqual(i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testMultipleSaves(self):
+
+ def _build_graph(start, stop):
+ iterator = dataset_ops.Dataset.range(start,
+ stop).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
+ return init_op, get_next, save_op, restore_op
+
+ start = 2
+ stop = 10
+ break_point1 = 5
+ break_point2 = 7
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, _ = _build_graph(start, stop)
+ with self.test_session(graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
+ for i in range(start, break_point1):
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, restore_op = _build_graph(start, stop)
+ with self.test_session(graph=g) as sess:
+ sess.run(restore_op)
+ for i in range(break_point1, break_point2):
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
+
+ break_point2 = 7
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, restore_op = _build_graph(start, stop)
+ with self.test_session(graph=g) as sess:
+ sess.run(restore_op)
+ for i in range(break_point2, stop):
+ self.assertEqual(i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testSaveRestoreWithRepeat(self):
+
+ def _build_graph(start, stop, num_epochs):
+ iterator = dataset_ops.Dataset.range(
+ start, stop).repeat(num_epochs).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
+ return init_op, get_next, save_op, restore_op
+
+ start = 2
+ stop = 10
+ num_epochs = 5
+ break_range = 5
+ break_epoch = 3
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, restore_op = _build_graph(
+ start, stop, num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
+ # Note: There is no checkpoint saved currently so a NotFoundError is
+ # raised.
+ with self.assertRaises(errors.NotFoundError):
+ sess.run(restore_op)
+ for _ in range(break_epoch - 1):
+ for i in range(start, stop):
+ self.assertEqual(i, sess.run(get_next))
+ for i in range(start, break_range):
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(restore_op)
+ for i in range(break_range, stop):
+ self.assertEqual(i, sess.run(get_next))
+ for _ in range(break_epoch, num_epochs):
+ for i in range(start, stop):
+ self.assertEqual(i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testSaveRestoreExhaustedIterator(self):
+
+ def _build_graph(start, stop, num_epochs):
+ iterator = dataset_ops.Dataset.range(
+ start, stop).repeat(num_epochs).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
+ return init_op, get_next, save_op, restore_op
+
+ start = 2
+ stop = 10
+ num_epochs = 5
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, restore_op = _build_graph(
+ start, stop, num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
+ # Note: There is no checkpoint saved currently so a NotFoundError is
+ # raised.
+ with self.assertRaises(errors.NotFoundError):
+ sess.run(restore_op)
+ for _ in range(num_epochs):
+ for i in range(start, stop):
+ self.assertEqual(i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+ sess.run(save_op)
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(restore_op)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
index 70b6ce442e..c8e7333b4b 100644
--- a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
@@ -26,8 +26,13 @@ from tensorflow.python.data.ops import readers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat
@@ -267,6 +272,299 @@ class FixedLengthRecordReaderTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(iterator.get_next())
+ def _iterator_checkpoint_path(self):
+ return os.path.join(self.get_temp_dir(), "iterator")
+
+ def _save_op(self, iterator_resource):
+ iterator_state_variant = gen_dataset_ops.serialize_iterator(
+ iterator_resource)
+ save_op = io_ops.write_file(
+ self._iterator_checkpoint_path(),
+ parsing_ops.serialize_tensor(iterator_state_variant))
+ return save_op
+
+ def _restore_op(self, iterator_resource):
+ iterator_state_variant = parsing_ops.parse_tensor(
+ io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant)
+ restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+ iterator_state_variant)
+ return restore_op
+
+ def _build_iterator_graph(self, num_epochs):
+ filenames = self._createFiles()
+ dataset = (readers.FixedLengthRecordDataset(
+ filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
+ .repeat(num_epochs))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next_op = iterator.get_next()
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
+ return init_op, get_next_op, save_op, restore_op
+
+ def _restore_iterator(self):
+ output_types = dtypes.string
+ output_shapes = tensor_shape.scalar()
+ iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
+ get_next = iterator.get_next()
+ restore_op = self._restore_op(iterator._iterator_resource)
+ return restore_op, get_next
+
+ def testSaveRestore(self):
+ num_epochs = 10
+ epoch_break = 5
+ file_break = self._num_files // 2
+ record_break = self._num_records // 2
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+ num_epochs=num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(init_op)
+ # Note: There is no checkpoint saved currently so a NotFoundError is
+ # raised.
+ with self.assertRaises(errors.NotFoundError):
+ sess.run(restore_op)
+ for epoch in range(num_epochs):
+ for f in range(self._num_files):
+ for r in range(self._num_records):
+ if (epoch == epoch_break and f == file_break and
+ r == record_break):
+ sess.run(save_op)
+ break
+ self.assertEqual(self._record(f, r), sess.run(get_next_op))
+ else:
+ continue
+ break
+ else:
+ continue
+ break
+ else:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+ num_epochs=num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(restore_op)
+ for epoch in range(num_epochs):
+ for f in range(self._num_files):
+ for r in range(self._num_records):
+ if (epoch < epoch_break or
+ (epoch == epoch_break and f < file_break) or
+ (epoch == epoch_break and f == file_break and
+ r < record_break)):
+ continue
+ self.assertEqual(self._record(f, r), sess.run(get_next_op))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ def testInitThenRestore(self):
+ # Note: Calling init_op before restore_op is redundant. This test just makes
+ # sure we do not fail if restore is called on an already initialized
+ # iterator resource.
+ num_epochs = 10
+ epoch_break = 5
+ file_break = self._num_files // 2
+ record_break = self._num_records // 2
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+ num_epochs=num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(init_op)
+ # Note: There is no checkpoint saved currently so a NotFoundError is
+ # raised.
+ with self.assertRaises(errors.NotFoundError):
+ sess.run(restore_op)
+ for epoch in range(num_epochs):
+ for f in range(self._num_files):
+ for r in range(self._num_records):
+ if (epoch == epoch_break and f == file_break and
+ r == record_break):
+ sess.run(save_op)
+ break
+ self.assertEqual(self._record(f, r), sess.run(get_next_op))
+ else:
+ continue
+ break
+ else:
+ continue
+ break
+ else:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+ num_epochs=num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(init_op)
+ sess.run(restore_op)
+ for epoch in range(num_epochs):
+ for f in range(self._num_files):
+ for r in range(self._num_records):
+ if (epoch < epoch_break or
+ (epoch == epoch_break and f < file_break) or
+ (epoch == epoch_break and f == file_break and
+ r < record_break)):
+ continue
+ self.assertEqual(self._record(f, r), sess.run(get_next_op))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ def testRestoreInModifiedGraph(self):
+ num_epochs = 10
+ num_epochs_1 = 20
+ epoch_break = 5
+ file_break = self._num_files // 2
+ record_break = self._num_records // 2
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+ num_epochs=num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(init_op)
+ # Note: There is no checkpoint saved currently so a NotFoundError is
+ # raised.
+ with self.assertRaises(errors.NotFoundError):
+ sess.run(restore_op)
+ for epoch in range(num_epochs):
+ for f in range(self._num_files):
+ for r in range(self._num_records):
+ if (epoch == epoch_break and f == file_break and
+ r == record_break):
+ sess.run(save_op)
+ break
+ self.assertEqual(self._record(f, r), sess.run(get_next_op))
+ else:
+ continue
+ break
+ else:
+ continue
+ break
+ else:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+ num_epochs=num_epochs_1)
+ with self.test_session(graph=g) as sess:
+ sess.run(restore_op)
+ for epoch in range(num_epochs):
+ for f in range(self._num_files):
+ for r in range(self._num_records):
+ if (epoch < epoch_break or
+ (epoch == epoch_break and f < file_break) or
+ (epoch == epoch_break and f == file_break and
+ r < record_break)):
+ continue
+ self.assertEqual(self._record(f, r), sess.run(get_next_op))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ def testRestoreWithoutBuildingDatasetGraph(self):
+ num_epochs = 10
+ epoch_break = 5
+ file_break = self._num_files // 2
+ record_break = self._num_records // 2
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+ num_epochs=num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(init_op)
+ # Note: There is no checkpoint saved currently so a NotFoundError is
+ # raised.
+ with self.assertRaises(errors.NotFoundError):
+ sess.run(restore_op)
+ for epoch in range(num_epochs):
+ for f in range(self._num_files):
+ for r in range(self._num_records):
+ if (epoch == epoch_break and f == file_break and
+ r == record_break):
+ sess.run(save_op)
+ break
+ self.assertEqual(self._record(f, r), sess.run(get_next_op))
+ else:
+ continue
+ break
+ else:
+ continue
+ break
+ else:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ with ops.Graph().as_default() as g:
+ restore_op, get_next_op = self._restore_iterator()
+ with self.test_session(graph=g) as sess:
+ sess.run(restore_op)
+ for epoch in range(num_epochs):
+ for f in range(self._num_files):
+ for r in range(self._num_records):
+ if (epoch < epoch_break or
+ (epoch == epoch_break and f < file_break) or
+ (epoch == epoch_break and f == file_break and
+ r < record_break)):
+ continue
+ self.assertEqual(self._record(f, r), sess.run(get_next_op))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ def testRestoreUnusedIterator(self):
+ num_epochs = 10
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+ num_epochs=num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(init_op)
+ # Note: There is no checkpoint saved currently so a NotFoundError is
+ # raised.
+ with self.assertRaises(errors.NotFoundError):
+ sess.run(restore_op)
+ # Save unused iterator.
+ sess.run(save_op)
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+ num_epochs=num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(restore_op)
+ for _ in range(num_epochs * self._num_files * self._num_records):
+ sess.run(get_next_op)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ def testRestoreExhaustedIterator(self):
+ num_epochs = 10
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+ num_epochs=num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(init_op)
+ # Note: There is no checkpoint saved currently so a NotFoundError is
+ # raised.
+ with self.assertRaises(errors.NotFoundError):
+ sess.run(restore_op)
+ for _ in range(num_epochs):
+ for f in range(self._num_files):
+ for r in range(self._num_records):
+ self.assertEqual(self._record(f, r), sess.run(get_next_op))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+ sess.run(save_op)
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+ num_epochs=num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(restore_op)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
class TFRecordDatasetTest(test.TestCase):
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index bd4b12b7e8..5396214956 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -117,6 +117,18 @@ class VariableScopeTest(test.TestCase):
w = variable_scope.get_variable("w", [])
self.assertEqual(w.dtype.base_dtype, dtypes.float16)
+ def testEagerVaribleStore(self):
+ with context.eager_mode():
+ store = variable_scope.EagerVariableStore()
+ with store.as_default():
+ v = variable_scope.get_variable("v", shape=(), trainable=True)
+ w = variable_scope.get_variable("w", shape=(), trainable=False)
+
+ self.assertTrue(v in store.variables())
+ self.assertTrue(w in store.variables())
+ self.assertTrue(v in store.trainable_variables())
+ self.assertFalse(w in store.trainable_variables())
+
@test_util.run_in_graph_and_eager_modes()
def testInitFromNonTensorValue(self):
v = variable_scope.get_variable("v4", initializer=4, dtype=dtypes.int32)
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py
index 4b3dadc112..43be08f8a1 100644
--- a/tensorflow/python/kernel_tests/xent_op_test.py
+++ b/tensorflow/python/kernel_tests/xent_op_test.py
@@ -181,6 +181,24 @@ class XentTest(test.TestCase):
print("cross entropy gradient err = ", err)
self.assertLess(err, 5e-8)
+ def testGradientLabelWithV2(self):
+ with self.test_session():
+ l = constant_op.constant(
+ [0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.5],
+ shape=[3, 4],
+ dtype=dtypes.float64,
+ name="l")
+ f = constant_op.constant(
+ [0.1, 0.2, 0.3, 0.4, 0.1, 0.4, 0.9, 1.6, 0.1, 0.8, 2.7, 6.4],
+ shape=[3, 4],
+ dtype=dtypes.float64,
+ name="f")
+ x = nn_ops.softmax_cross_entropy_with_logits_v2(labels=l, logits=f,
+ name="xent")
+ err = gradient_checker.compute_gradient_error(l, [3, 4], x, [3])
+
+ self.assertLess(err, 5e-8)
+
def testSecondGradient(self):
with self.test_session() as sess:
l = constant_op.constant([0.0, 0.0, 1.0/3, 0.0,
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index db608aa79a..c5bf4c6080 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -401,10 +401,11 @@ class Layer(object):
"""
return input_shape
- def _make_unique_name(self, name_uid_map=None, avoid_names=None):
+ def _make_unique_name(self, name_uid_map=None, avoid_names=None,
+ namespace=''):
base_name = _to_snake_case(self.__class__.__name__)
name = _unique_layer_name(base_name, name_uid_map=name_uid_map,
- avoid_names=avoid_names)
+ avoid_names=avoid_names, namespace=namespace)
return (name, base_name)
def _set_scope(self, scope=None):
@@ -641,7 +642,7 @@ class Layer(object):
for output in output_list:
with ops.name_scope('ActivityRegularizer'):
activity_regularization = self._activity_regularizer(output)
- self.add_loss(activity_regularization)
+ self.add_loss(activity_regularization, inputs=inputs)
if not in_deferred_mode:
# TODO(fchollet): consider how masking will work with deferred mode.
@@ -2370,7 +2371,7 @@ def _get_default_graph_uid_map():
return name_uid_map
-def _unique_layer_name(name, name_uid_map=None, avoid_names=None):
+def _unique_layer_name(name, name_uid_map=None, avoid_names=None, namespace=''):
"""Makes a layer name (or arbitrary string) unique within a TensorFlow graph.
Arguments:
@@ -2379,6 +2380,9 @@ def _unique_layer_name(name, name_uid_map=None, avoid_names=None):
names. If None (default), uses a per-Graph dictionary.
avoid_names: An optional set or dict with names which should not be used. If
None (default) does not avoid any names.
+ namespace: Gets a name which is unique within the (graph, namespace). Layers
+ which are not Networks use a blank namespace and so get graph-global
+ names.
Returns:
Unique string name.
@@ -2396,6 +2400,7 @@ def _unique_layer_name(name, name_uid_map=None, avoid_names=None):
avoid_names = set()
proposed_name = None
while proposed_name is None or proposed_name in avoid_names:
- name_uid_map[name] += 1
- proposed_name = name + '_' + str(name_uid_map[name])
+ name_key = (namespace, name)
+ name_uid_map[name_key] += 1
+ proposed_name = name + '_' + str(name_uid_map[name_key])
return proposed_name
diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py
index 71eff2f965..509ad5a7af 100644
--- a/tensorflow/python/layers/base_test.py
+++ b/tensorflow/python/layers/base_test.py
@@ -47,7 +47,7 @@ class BaseLayerTest(test.TestCase):
self.assertEqual(layer.trainable_variables, [])
self.assertEqual(layer.non_trainable_variables, [])
if context.in_graph_mode():
- # updates, losses only suppported in GRAPH mode
+ # updates, losses only supported in GRAPH mode
self.assertEqual(layer.updates, [])
self.assertEqual(layer.losses, [])
self.assertEqual(layer.built, False)
@@ -574,6 +574,13 @@ class BaseLayerTest(test.TestCase):
self.assertEqual(3, result['label'].numpy())
self.assertEqual(4.0, result['logits'].numpy())
+ def testActivityRegularizer(self):
+ regularizer = math_ops.reduce_sum
+ layer = base_layers.Layer(activity_regularizer=regularizer)
+ x = array_ops.placeholder('int32')
+ layer.apply(x)
+ self.assertEqual(len(layer.get_losses_for(x)), 1)
+
class NetworkTest(test.TestCase):
diff --git a/tensorflow/python/lib/core/strings.i b/tensorflow/python/lib/core/strings.i
index 938c13e30e..9d807e51be 100644
--- a/tensorflow/python/lib/core/strings.i
+++ b/tensorflow/python/lib/core/strings.i
@@ -40,7 +40,7 @@ limitations under the License.
// Returns true on success, false on failure.
bool _BytesToStringPiece(PyObject* obj, tensorflow::StringPiece* result) {
if (obj == Py_None) {
- result->clear();
+ *result = tensorflow::StringPiece();
} else {
char* ptr;
Py_ssize_t len;
@@ -48,7 +48,7 @@ bool _BytesToStringPiece(PyObject* obj, tensorflow::StringPiece* result) {
// Python has raised an error (likely TypeError or UnicodeEncodeError).
return false;
}
- result->set(ptr, len);
+ *result = tensorflow::StringPiece(ptr, len);
}
return true;
}
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 3c025881cb..87f8d14860 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -460,7 +460,11 @@ def _GatherNdGrad(op, grad):
ref = op.inputs[0]
indices = op.inputs[1]
ref_shape = array_ops.shape(ref, out_type=indices.dtype)
- ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
+ if indices.shape.ndims == 2 and indices.shape[-1].value == 1:
+ ref_grad = ops.IndexedSlices(grad, array_ops.squeeze(indices, axis=-1),
+ ref_shape)
+ else:
+ ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
return [ref_grad, None]
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index f5f1278bfd..037ab4ff50 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1663,6 +1663,8 @@ def placeholder(dtype, shape=None, name=None):
print(sess.run(y, feed_dict={x: rand_array})) # Will succeed.
```
+ @compatibility{eager} Placeholders are not compatible with eager execution.
+
Args:
dtype: The type of elements in the tensor to be fed.
shape: The shape of the tensor to be fed (optional). If the shape is not
@@ -1672,7 +1674,14 @@ def placeholder(dtype, shape=None, name=None):
Returns:
A `Tensor` that may be used as a handle for feeding a value, but not
evaluated directly.
+
+ Raises:
+ RuntimeError: if eager execution is enabled
"""
+ if context.in_eager_mode():
+ raise RuntimeError("tf.placeholder() is not compatible with "
+ "eager execution.")
+
return gen_array_ops._placeholder(dtype=dtype, shape=shape, name=name)
@@ -1716,6 +1725,8 @@ def sparse_placeholder(dtype, shape=None, name=None):
print(sess.run(y, feed_dict={x: sp_value})) # Will succeed.
```
+ @compatibility{eager} Placeholders are not compatible with eager execution.
+
Args:
dtype: The type of `values` elements in the tensor to be fed.
shape: The shape of the tensor to be fed (optional). If the shape is not
@@ -1725,7 +1736,14 @@ def sparse_placeholder(dtype, shape=None, name=None):
Returns:
A `SparseTensor` that may be used as a handle for feeding a value, but not
evaluated directly.
+
+ Raises:
+ RuntimeError: if eager execution is enabled
"""
+ if context.in_eager_mode():
+ raise RuntimeError("tf.placeholder() is not compatible with "
+ "eager execution.")
+
shape_name = (name + "/shape") if name is not None else None
shape, rank = _normalize_sparse_shape(shape, shape_name)
if shape is None:
diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py
index ceee009104..7e509f72c1 100644
--- a/tensorflow/python/ops/check_ops.py
+++ b/tensorflow/python/ops/check_ops.py
@@ -48,6 +48,7 @@ import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_util
@@ -96,10 +97,11 @@ def _maybe_constant_value_string(t):
def _assert_static(condition, data):
- """Raises a static ValueError with as much information as possible."""
+ """Raises a InvalidArgumentError with as much information as possible."""
if not condition:
data_static = [_maybe_constant_value_string(x) for x in data]
- raise ValueError('\n'.join(data_static))
+ raise errors.InvalidArgumentError(node_def=None, op=None,
+ message='\n'.join(data_static))
def assert_proper_iterable(values):
@@ -303,11 +305,60 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None):
Returns:
Op that raises `InvalidArgumentError` if `x == y` is False.
+ @compatibility{eager} returns None
+
+ Raises:
+ InvalidArgumentError if the check can be performed immediately and
+ `x == y` is False. The check can be performed immediately during
+ eager execution or if `x` and `y` are statically known.
"""
message = message or ''
with ops.name_scope(name, 'assert_equal', [x, y, data]):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
+
+ if context.in_eager_mode():
+ eq = math_ops.equal(x, y)
+ condition = math_ops.reduce_all(eq)
+ if not condition:
+ # Prepare a message with first elements of x and y
+ summary_msg = ''
+ if summarize:
+ # reshape((-1,)) is the fastest way to get a flat array view.
+ x_np = x.numpy().reshape((-1,))
+ y_np = y.numpy().reshape((-1,))
+ x_sum = min(x_np.size, summarize)
+ y_sum = min(y_np.size, summarize)
+ summary_msg = ('First %d elements of x:\n%s\n'
+ 'First %d elements of y:\n%s\n' %
+ (x_sum, x_np[:x_sum],
+ y_sum, y_np[:y_sum]))
+
+ # Get the values that actually differed and their indices
+ mask = math_ops.logical_not(eq)
+ indices = array_ops.where(mask)
+ indices_np = indices.numpy()
+ x_vals = array_ops.boolean_mask(x, mask)
+ y_vals = array_ops.boolean_mask(y, mask)
+ diff_to_print = 0
+ if summarize:
+ diff_to_print = min(summarize, indices_np.size)
+
+ raise errors.InvalidArgumentError(
+ node_def=None, op=None,
+ message=('%s\nCondition x == y did not hold.\n'
+ 'Indices of first %s different values:\n%s\n'
+ 'Corresponding x values:\n%s\n'
+ 'Corresponding y values:\n%s\n'
+ '%s'
+ %
+ (message or '',
+ diff_to_print, indices_np[:diff_to_print],
+ x_vals.numpy().reshape((-1,))[:diff_to_print],
+ y_vals.numpy().reshape((-1,))[:diff_to_print],
+ summary_msg)))
+ return
+
if data is None:
data = [
message,
@@ -356,12 +407,19 @@ def assert_none_equal(
with ops.name_scope(name, 'assert_none_equal', [x, y, data]):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
+ if context.in_eager_mode():
+ x_name = 'x'
+ y_name = 'y'
+ else:
+ x_name = x.name
+ y_name = y.name
+
if data is None:
data = [
message,
- 'Condition x != y did not hold for every single element:'
- 'x (%s) = ' % x.name, x,
- 'y (%s) = ' % y.name, y
+ 'Condition x != y did not hold for every single element:',
+ 'x (%s) = ' % x_name, x,
+ 'y (%s) = ' % y_name, y
]
condition = math_ops.reduce_all(math_ops.not_equal(x, y))
return control_flow_ops.Assert(condition, data, summarize=summarize)
@@ -397,11 +455,18 @@ def assert_less(x, y, data=None, summarize=None, message=None, name=None):
with ops.name_scope(name, 'assert_less', [x, y, data]):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
+ if context.in_eager_mode():
+ x_name = 'x'
+ y_name = 'y'
+ else:
+ x_name = x.name
+ y_name = y.name
+
if data is None:
data = [
message,
- 'Condition x < y did not hold element-wise:'
- 'x (%s) = ' % x.name, x, 'y (%s) = ' % y.name, y
+ 'Condition x < y did not hold element-wise:',
+ 'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y
]
condition = math_ops.reduce_all(math_ops.less(x, y))
return control_flow_ops.Assert(condition, data, summarize=summarize)
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 10d8e01304..d33d4cd597 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -60,6 +60,7 @@ from tensorflow.core.protobuf import control_flow_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
@@ -86,6 +87,29 @@ from tensorflow.python.util import tf_should_use
_basetuple = tuple
+def _summarize_eager(tensor, summarize=None):
+ """Returns a summarized string representation of eager `tensor`.
+
+ Args:
+ tensor: EagerTensor to summarize
+ summarize: Include these many first elements of `array`
+ """
+ # reshape((-1,)) is the fastest way to get a flat array view
+ if tensor._rank(): # pylint: disable=protected-access
+ flat = tensor.numpy().reshape((-1,))
+ lst = [str(x) for x in flat[:summarize]]
+ if len(lst) < flat.size:
+ lst.append("...")
+ else:
+ # tensor.numpy() returns a scalar for zero dimensional arrays
+ if summarize != 0:
+ lst = [str(tensor.numpy())]
+ else:
+ lst = []
+
+ return ", ".join(lst)
+
+
# pylint: disable=protected-access
@@ -98,7 +122,8 @@ def Assert(condition, data, summarize=None, name=None):
If `condition` evaluates to false, print the list of tensors in `data`.
`summarize` determines how many entries of the tensors to print.
- NOTE: To ensure that Assert executes, one usually attaches a dependency:
+ NOTE: In graph mode, to ensure that Assert executes, one usually attaches
+ a dependency:
```python
# Ensure maximum element of x is smaller or equal to 1
@@ -117,7 +142,21 @@ def Assert(condition, data, summarize=None, name=None):
assert_op: An `Operation` that, when executed, raises a
`tf.errors.InvalidArgumentError` if `condition` is not true.
@compatibility{eager} returns None.
+
+ Raises:
+ @compatibility{eager} `tf.errors.InvalidArgumentError` if `condition`
+ is not true
"""
+ if context.in_eager_mode():
+ if not condition:
+ xs = ops.convert_n_to_tensor(data)
+ data_str = [_summarize_eager(x, summarize) for x in xs]
+ raise errors.InvalidArgumentError(
+ node_def=None, op=None,
+ message="Expected '%s' to be true. Summarized data: %s" % (
+ condition, "\n".join(data_str)))
+ return
+
with ops.name_scope(name, "Assert", [condition, data]) as name:
xs = ops.convert_n_to_tensor(data)
if all([x.dtype in {dtypes.string, dtypes.int32} for x in xs]):
@@ -1838,8 +1877,8 @@ def cond(pred, true_fn=None, false_fn=None, strict=False, name=None,
with ops.name_scope(name, "cond", [pred]):
if context.in_eager_mode():
if pred:
- return true_fn()
- return false_fn()
+ return _UnpackIfSingleton(true_fn())
+ return _UnpackIfSingleton(false_fn())
# Add the Switch to the graph.
if isinstance(pred, bool):
diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py
index 477c0d1cb4..f037767cf4 100644
--- a/tensorflow/python/ops/ctc_ops.py
+++ b/tensorflow/python/ops/ctc_ops.py
@@ -22,8 +22,8 @@ from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import gen_ctc_ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_ctc_ops
from tensorflow.python.ops.nn_grad import _BroadcastMul
@@ -38,7 +38,8 @@ def ctc_loss(labels, inputs, sequence_length,
[A. Graves, S. Fernandez, F. Gomez, J. Schmidhuber.
Connectionist Temporal Classification: Labeling Unsegmented Sequence Data
- with Recurrent Neural Networks. ICML 2006, Pittsburgh, USA, pp. 369-376.](http://www.cs.toronto.edu/~graves/icml_2006.pdf)
+ with Recurrent Neural Networks. ICML 2006, Pittsburgh, USA,
+ pp. 369-376.](http://www.cs.toronto.edu/~graves/icml_2006.pdf)
Input requirements:
@@ -108,9 +109,9 @@ def ctc_loss(labels, inputs, sequence_length,
See `core/ops/ctc_ops.cc` for more details.
inputs: 3-D `float` `Tensor`.
If time_major == False, this will be a `Tensor` shaped:
- `[batch_size x max_time x num_classes]`.
+ `[batch_size, max_time, num_classes]`.
If time_major == True (default), this will be a `Tensor` shaped:
- `[max_time x batch_size x num_classes]`.
+ `[max_time, batch_size, num_classes]`.
The logits.
sequence_length: 1-D `int32` vector, size `[batch_size]`.
The sequence lengths.
@@ -120,15 +121,18 @@ def ctc_loss(labels, inputs, sequence_length,
ignore_longer_outputs_than_inputs: Boolean. Default: False.
If True, sequences with longer outputs than inputs will be ignored.
time_major: The shape format of the `inputs` Tensors.
- If True, these `Tensors` must be shaped `[max_time, batch_size, num_classes]`.
- If False, these `Tensors` must be shaped `[batch_size, max_time, num_classes]`.
- Using `time_major = True` (default) is a bit more efficient because it avoids
- transposes at the beginning of the ctc_loss calculation. However, most
- TensorFlow data is batch-major, so by this function also accepts inputs
- in batch-major form.
+ If True, these `Tensors` must be shaped `[max_time, batch_size,
+ num_classes]`.
+ If False, these `Tensors` must be shaped `[batch_size, max_time,
+ num_classes]`.
+ Using `time_major = True` (default) is a bit more efficient because it
+ avoids transposes at the beginning of the ctc_loss calculation. However,
+ most TensorFlow data is batch-major, so by this function also accepts
+ inputs in batch-major form.
Returns:
- A 1-D `float` `Tensor`, size `[batch]`, containing the negative log probabilities.
+ A 1-D `float` `Tensor`, size `[batch]`, containing the negative log
+ probabilities.
Raises:
TypeError: if labels is not a `SparseTensor`.
@@ -198,7 +202,7 @@ def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
Args:
inputs: 3-D `float` `Tensor` sized
- `[max_time x batch_size x num_classes]`. The logits.
+ `[max_time, batch_size, num_classes]`. The logits.
sequence_length: 1-D `int32` vector containing sequence lengths,
having size `[batch_size]`.
merge_repeated: Boolean. Default: True.
@@ -207,7 +211,7 @@ def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
A tuple `(decoded, neg_sum_logits)` where
decoded: A single-element list. `decoded[0]`
is an `SparseTensor` containing the decoded outputs s.t.:
- `decoded.indices`: Indices matrix `(total_decoded_outputs x 2)`.
+ `decoded.indices`: Indices matrix `(total_decoded_outputs, 2)`.
The rows store: `[batch, time]`.
`decoded.values`: Values vector, size `(total_decoded_outputs)`.
The vector stores the decoded classes.
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index 8c1ccc6840..f4561d1a83 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -191,12 +191,9 @@ def _embedding_lookup_and_transform(params,
(flat_ids - extras) // ids_per_partition)
# Emulate a conditional using a boolean indicator tensor
- is_in_first_extras_partitions = math_ops.cast(p_assignments < extras,
- flat_ids.dtype)
- new_ids = (is_in_first_extras_partitions * (flat_ids %
- (ids_per_partition + 1)) +
- (1 - is_in_first_extras_partitions) *
- ((flat_ids - extras) % ids_per_partition))
+ new_ids = array_ops.where(p_assignments < extras,
+ flat_ids % (ids_per_partition + 1),
+ (flat_ids - extras) % ids_per_partition)
else:
raise ValueError("Unrecognized partition strategy: " +
partition_strategy)
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 79af3ac117..ee1a00623a 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -74,6 +74,7 @@ See the @{$python/nn} guide.
@@softmax
@@log_softmax
@@softmax_cross_entropy_with_logits
+@@softmax_cross_entropy_with_logits_v2
@@sparse_softmax_cross_entropy_with_logits
@@weighted_cross_entropy_with_logits
@@embedding_lookup
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 557f39fb42..4b406ba840 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -420,7 +420,6 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
# grad_loss is the backprop for cost, and we multiply it with the gradients
# (which is output[1])
# grad_grad is the backprop for softmax gradient.
- # There is no gradient for the labels
#
# Second derivative is just softmax derivative w.r.t. logits.
softmax_grad = op.outputs[1]
@@ -436,15 +435,15 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
const_fill_value = tensor_util.constant_value(g)
return const_fill_value is not None and (const_fill_value == 0).all()
+ logits = op.inputs[0]
if grad_grad is not None and not IsZero(grad_grad):
- logits = op.inputs[0]
softmax = nn_ops.softmax(logits)
grad += ((grad_grad - array_ops.squeeze(
math_ops.matmul(grad_grad[:, None, :],
softmax[:, :, None]), axis=1)) * softmax)
- return grad, None
+ return grad, _BroadcastMul(grad_loss, -nn_ops.log_softmax(logits))
@ops.RegisterGradient("SparseSoftmaxCrossEntropyWithLogits")
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 7297d2f349..da037a7983 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -277,9 +277,6 @@ def _swish_shape(op):
return [op.inputs[0].shape]
-# Set noinline=True so that sigmoid(features) is re-computed during
-# backprop, and we can free the sigmoid(features) expression immediately
-# after use during the forward pass.
@function.Defun(shape_func=_swish_shape, func_name="swish_grad", noinline=True)
def _swish_grad(features, grad):
"""Gradient of Swish function defined below."""
@@ -289,6 +286,11 @@ def _swish_grad(features, grad):
return grad * activation_grad
+# Naively, x * tf.nn.sigmoid(x) requires keeping both x and sigmoid(x) around
+# for backprop, effectively doubling the tensor's memory consumption. We use a
+# @Defun decorator with noinline=True so that sigmoid(features) is re-computed
+# during backprop, and we can free the sigmoid(features) expression immediately
+# after use during the forward pass.
@function.Defun(
grad_func=_swish_grad,
shape_func=_swish_shape,
@@ -298,7 +300,7 @@ def swish(features):
# pylint: disable=g-doc-args
"""Computes the Swish activation function: `x * sigmoid(x)`.
- Source: "Swish: a Self-Gated Activation Function" (Ramachandran et al. 2017)
+ Source: "Searching for Activation Functions" (Ramachandran et al. 2017)
https://arxiv.org/abs/1710.05941
Args:
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index c4de2c7f00..61fa462988 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
+
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_nn_ops import *
@@ -40,6 +41,7 @@ from tensorflow.python.ops.gen_nn_ops import *
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.deprecation import deprecated_argument_lookup
+from tensorflow.python.util import deprecation
# Aliases for some automatically-generated names.
local_response_normalization = gen_nn_ops.lrn
@@ -1711,9 +1713,9 @@ def _ensure_xent_args(name, sentinel, labels, logits):
raise ValueError("Both labels and logits must be provided.")
-def softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name
- labels=None, logits=None,
- dim=-1, name=None):
+def softmax_cross_entropy_with_logits_v2(_sentinel=None, # pylint: disable=invalid-name
+ labels=None, logits=None,
+ dim=-1, name=None):
"""Computes softmax cross entropy between `logits` and `labels`.
Measures the probability error in discrete classification tasks in which the
@@ -1737,6 +1739,10 @@ def softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid
`[batch_size, num_classes]` and the same dtype (either `float16`, `float32`,
or `float64`).
+ Backpropagation will happen into both `logits` and `labels`. To disallow
+ backpropagation into `labels`, pass label tensors through a `stop_gradients`
+ before feeding it to this function.
+
**Note that to avoid confusion, it is required to pass only named arguments to
this function.**
@@ -1758,57 +1764,123 @@ def softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid
# could break users who call this with bad labels, but disregard the bad
# results.
- logits = ops.convert_to_tensor(logits)
- labels = ops.convert_to_tensor(labels)
- precise_logits = math_ops.cast(logits, dtypes.float32) if (
- logits.dtype == dtypes.float16) else logits
- # labels and logits must be of the same type
- labels = math_ops.cast(labels, precise_logits.dtype)
- input_rank = array_ops.rank(precise_logits)
- # For shape inference.
- shape = logits.get_shape()
+ with ops.name_scope(
+ name, "softmax_cross_entropy_with_logits", [logits, labels]) as name:
+ logits = ops.convert_to_tensor(logits, name="logits")
+ labels = ops.convert_to_tensor(labels, name="labels")
+ precise_logits = math_ops.cast(logits, dtypes.float32) if (
+ logits.dtype == dtypes.float16) else logits
+ # labels and logits must be of the same type
+ labels = math_ops.cast(labels, precise_logits.dtype)
+ input_rank = array_ops.rank(precise_logits)
+ # For shape inference.
+ shape = logits.get_shape()
- # Move the dim to the end if dim is not the last dimension.
- if dim is not -1:
- def _move_dim_to_end(tensor, dim_index, rank):
- return array_ops.transpose(tensor,
- array_ops.concat([
- math_ops.range(dim_index),
- math_ops.range(dim_index + 1, rank),
- [dim_index]
- ], 0))
+ # Move the dim to the end if dim is not the last dimension.
+ if dim is not -1:
+ def _move_dim_to_end(tensor, dim_index, rank):
+ return array_ops.transpose(tensor,
+ array_ops.concat([
+ math_ops.range(dim_index),
+ math_ops.range(dim_index + 1, rank),
+ [dim_index]
+ ], 0))
- precise_logits = _move_dim_to_end(precise_logits, dim, input_rank)
- labels = _move_dim_to_end(labels, dim, input_rank)
+ precise_logits = _move_dim_to_end(precise_logits, dim, input_rank)
+ labels = _move_dim_to_end(labels, dim, input_rank)
- input_shape = array_ops.shape(precise_logits)
+ input_shape = array_ops.shape(precise_logits)
- # Make precise_logits and labels into matrices.
- precise_logits = _flatten_outer_dims(precise_logits)
- labels = _flatten_outer_dims(labels)
+ # Make precise_logits and labels into matrices.
+ precise_logits = _flatten_outer_dims(precise_logits)
+ labels = _flatten_outer_dims(labels)
- # Do the actual op computation.
- # The second output tensor contains the gradients. We use it in
- # _CrossEntropyGrad() in nn_grad but not here.
- cost, unused_backprop = gen_nn_ops._softmax_cross_entropy_with_logits(
- precise_logits, labels, name=name)
+ # Do the actual op computation.
+ # The second output tensor contains the gradients. We use it in
+ # _CrossEntropyGrad() in nn_grad but not here.
+ cost, unused_backprop = gen_nn_ops._softmax_cross_entropy_with_logits(
+ precise_logits, labels, name=name)
- # The output cost shape should be the input minus dim.
- output_shape = array_ops.slice(input_shape, [0],
- [math_ops.subtract(input_rank, 1)])
- cost = array_ops.reshape(cost, output_shape)
+ # The output cost shape should be the input minus dim.
+ output_shape = array_ops.slice(input_shape, [0],
+ [math_ops.subtract(input_rank, 1)])
+ cost = array_ops.reshape(cost, output_shape)
- # Make shape inference work since reshape and transpose may erase its static
- # shape.
- if context.in_graph_mode() and shape is not None and shape.dims is not None:
- shape = shape.as_list()
- del shape[dim]
- cost.set_shape(shape)
+ # Make shape inference work since reshape and transpose may erase its static
+ # shape.
+ if context.in_graph_mode() and shape is not None and shape.dims is not None:
+ shape = shape.as_list()
+ del shape[dim]
+ cost.set_shape(shape)
- if logits.dtype == dtypes.float16:
- return math_ops.cast(cost, dtypes.float16)
- else:
- return cost
+ if logits.dtype == dtypes.float16:
+ return math_ops.cast(cost, dtypes.float16)
+ else:
+ return cost
+
+
+_XENT_DEPRECATION = """
+Future major versions of TensorFlow will allow gradients to flow
+into the labels input on backprop by default.
+
+See tf.nn.softmax_cross_entropy_with_logits_v2.
+"""
+
+
+@deprecation.deprecated(date=None, instructions=_XENT_DEPRECATION)
+def softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name
+ labels=None, logits=None,
+ dim=-1, name=None):
+ """Computes softmax cross entropy between `logits` and `labels`.
+
+ Measures the probability error in discrete classification tasks in which the
+ classes are mutually exclusive (each entry is in exactly one class). For
+ example, each CIFAR-10 image is labeled with one and only one label: an image
+ can be a dog or a truck, but not both.
+
+ **NOTE:** While the classes are mutually exclusive, their probabilities
+ need not be. All that is required is that each row of `labels` is
+ a valid probability distribution. If they are not, the computation of the
+ gradient will be incorrect.
+
+ If using exclusive `labels` (wherein one and only
+ one class is true at a time), see `sparse_softmax_cross_entropy_with_logits`.
+
+ **WARNING:** This op expects unscaled logits, since it performs a `softmax`
+ on `logits` internally for efficiency. Do not call this op with the
+ output of `softmax`, as it will produce incorrect results.
+
+ `logits` and `labels` must have the same shape, e.g.
+ `[batch_size, num_classes]` and the same dtype (either `float16`, `float32`,
+ or `float64`).
+
+ Backpropagation will happen only into `logits`. To calculate a cross entropy
+ loss that allows backpropagation into both `logits` and `labels`, see
+ @{tf.nn.softmax_cross_entropy_with_logits_v2}.
+
+ **Note that to avoid confusion, it is required to pass only named arguments to
+ this function.**
+
+ Args:
+ _sentinel: Used to prevent positional parameters. Internal, do not use.
+ labels: Each row `labels[i]` must be a valid probability distribution.
+ logits: Unscaled log probabilities.
+ dim: The class dimension. Defaulted to -1 which is the last dimension.
+ name: A name for the operation (optional).
+
+ Returns:
+ A 1-D `Tensor` of length `batch_size` of the same type as `logits` with the
+ softmax cross entropy loss.
+ """
+ _ensure_xent_args("softmax_cross_entropy_with_logits", _sentinel,
+ labels, logits)
+
+ with ops.name_scope(
+ name, "softmax_cross_entropy_with_logits_sg", [logits, labels]) as name:
+ labels = array_ops.stop_gradient(labels, name="labels_stop_gradient")
+
+ return softmax_cross_entropy_with_logits_v2(
+ labels=labels, logits=logits, dim=dim, name=name)
def sparse_softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 21c7ed361d..df66302402 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -32,6 +32,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
@@ -134,6 +135,13 @@ def _infer_state_dtype(explicit_dtype, state):
return state.dtype
+def _maybe_tensor_shape_from_tensor(shape):
+ if isinstance(shape, ops.Tensor):
+ return tensor_shape.as_shape(tensor_util.constant_value(shape))
+ else:
+ return shape
+
+
# pylint: disable=unused-argument
def _rnn_step(
time, sequence_length, min_sequence_length, max_sequence_length,
@@ -715,18 +723,28 @@ def _dynamic_rnn_loop(cell,
with ops.name_scope("dynamic_rnn") as scope:
base_name = scope
- def _create_ta(name, dtype):
+ def _create_ta(name, element_shape, dtype):
return tensor_array_ops.TensorArray(dtype=dtype,
size=time_steps,
+ element_shape=element_shape,
tensor_array_name=base_name + name)
in_graph_mode = context.in_graph_mode()
if in_graph_mode:
- output_ta = tuple(_create_ta("output_%d" % i,
- _infer_state_dtype(dtype, state))
- for i in range(len(flat_output_size)))
- input_ta = tuple(_create_ta("input_%d" % i, flat_input[i].dtype)
- for i in range(len(flat_input)))
+ output_ta = tuple(
+ _create_ta(
+ "output_%d" % i,
+ element_shape=(tensor_shape.TensorShape([const_batch_size])
+ .concatenate(
+ _maybe_tensor_shape_from_tensor(out_size))),
+ dtype=_infer_state_dtype(dtype, state))
+ for i, out_size in enumerate(flat_output_size))
+ input_ta = tuple(
+ _create_ta(
+ "input_%d" % i,
+ element_shape=flat_input_i.shape[1:],
+ dtype=flat_input_i.dtype)
+ for i, flat_input_i in enumerate(flat_input))
input_ta = tuple(ta.unstack(input_)
for ta, input_ in zip(input_ta, flat_input))
else:
@@ -1007,6 +1025,7 @@ def raw_rnn(cell, loop_fn,
static_batch_size.merge_with(input_shape_i[0])
batch_size = static_batch_size.value
+ const_batch_size = batch_size
if batch_size is None:
batch_size = array_ops.shape(flat_input[0])[0]
@@ -1029,8 +1048,15 @@ def raw_rnn(cell, loop_fn,
flat_emit_ta = [
tensor_array_ops.TensorArray(
- dtype=dtype_i, dynamic_size=True, size=0, name="rnn_output_%d" % i)
- for i, dtype_i in enumerate(flat_emit_dtypes)]
+ dtype=dtype_i,
+ dynamic_size=True,
+ element_shape=(tensor_shape.TensorShape([const_batch_size])
+ .concatenate(
+ _maybe_tensor_shape_from_tensor(size_i))),
+ size=0,
+ name="rnn_output_%d" % i)
+ for i, (dtype_i, size_i)
+ in enumerate(zip(flat_emit_dtypes, flat_emit_size))]
emit_ta = nest.pack_sequence_as(structure=emit_structure,
flat_sequence=flat_emit_ta)
flat_zero_emit = [
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 92fa928eed..9a0ff75594 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -1227,6 +1227,11 @@ class EagerVariableStore(object):
def variables(self):
return self._store._vars.values() # pylint: disable=protected-access
+ def trainable_variables(self):
+ # pylint: disable=protected-access
+ return [x for x in self._store._vars.values() if x._trainable]
+ # pylint: enable=protected-access
+
def get_variable(name,
shape=None,
diff --git a/tensorflow/python/pywrap_dlopen_global_flags.py b/tensorflow/python/pywrap_dlopen_global_flags.py
index 509fc2170c..411334f480 100644
--- a/tensorflow/python/pywrap_dlopen_global_flags.py
+++ b/tensorflow/python/pywrap_dlopen_global_flags.py
@@ -28,13 +28,12 @@ from __future__ import print_function
import ctypes
import sys
-# On UNIX-based platforms, pywrap_tensorflow is a SWIG-generated
-# python library that dynamically loads _pywrap_tensorflow.so. The
-# default mode for loading keeps all the symbol private and not
-# visible to other libraries that may be loaded. Setting the mode to
-# RTLD_GLOBAL to make the symbols visible, so that custom op libraries
-# imported using `tf.load_op_library()` can access symbols defined in
-# _pywrap_tensorflow.so.
+# On UNIX-based platforms, pywrap_tensorflow is a SWIG-generated python library
+# that dynamically loads _pywrap_tensorflow.so. The default mode for loading
+# keeps all the symbol private and not visible to other libraries that may be
+# loaded. Setting the mode to RTLD_GLOBAL to make the symbols visible, so that
+# custom op libraries imported using `tf.load_op_library()` can access symbols
+# defined in _pywrap_tensorflow.so.
_use_rtld_global = (hasattr(sys, 'getdlopenflags')
and hasattr(sys, 'setdlopenflags'))
if _use_rtld_global:
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 637f738fed..cbacf458a0 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -29,7 +29,7 @@ limitations under the License.
%rename("%s") TFE_Py_TapeWatch;
%rename("%s") TFE_Py_TapeDeleteTrace;
%rename("%s") TFE_Py_TapeRecordOperation;
-%rename("%s") TFE_Py_TapeExport;
+%rename("%s") TFE_Py_TapeGradient;
%rename("%s") TFE_NewContextOptions;
%rename("%s") TFE_ContextOptionsSetConfig;
%rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy;
@@ -125,7 +125,7 @@ limitations under the License.
SWIG_fail;
}
if (EagerTensor_CheckExact(elem)) {
- (*$1)[i] = EagerTensorHandle(elem);
+ (*$1)[i] = EagerTensor_Handle(elem);
} else {
SWIG_exception_fail(SWIG_TypeError,
"provided list of inputs contains objects other "
diff --git a/tensorflow/python/tools/inspect_checkpoint.py b/tensorflow/python/tools/inspect_checkpoint.py
index 47a74e5abf..8716058e61 100644
--- a/tensorflow/python/tools/inspect_checkpoint.py
+++ b/tensorflow/python/tools/inspect_checkpoint.py
@@ -29,7 +29,8 @@ from tensorflow.python.platform import flags
FLAGS = None
-def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
+def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors,
+ all_tensor_names):
"""Prints tensors in a checkpoint file.
If no `tensor_name` is provided, prints the tensor names and shapes
@@ -41,14 +42,16 @@ def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
file_name: Name of the checkpoint file.
tensor_name: Name of the tensor in the checkpoint file to print.
all_tensors: Boolean indicating whether to print all tensors.
+ all_tensor_names: Boolean indicating whether to print all tensor names.
"""
try:
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
- if all_tensors:
+ if all_tensors or all_tensor_names:
var_to_shape_map = reader.get_variable_to_shape_map()
for key in sorted(var_to_shape_map):
print("tensor_name: ", key)
- print(reader.get_tensor(key))
+ if all_tensors:
+ print(reader.get_tensor(key))
elif not tensor_name:
print(reader.debug_string().decode("utf-8"))
else:
@@ -104,11 +107,14 @@ def parse_numpy_printoption(kv_str):
def main(unused_argv):
if not FLAGS.file_name:
print("Usage: inspect_checkpoint --file_name=checkpoint_file_name "
- "[--tensor_name=tensor_to_print]")
+ "[--tensor_name=tensor_to_print] "
+ "[--all_tensors] "
+ "[--all_tensor_names] "
+ "[--printoptions]")
sys.exit(1)
else:
print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name,
- FLAGS.all_tensors)
+ FLAGS.all_tensors, FLAGS.all_tensor_names)
if __name__ == "__main__":
@@ -131,6 +137,13 @@ if __name__ == "__main__":
default=False,
help="If True, print the values of all the tensors.")
parser.add_argument(
+ "--all_tensor_names",
+ nargs="?",
+ const=True,
+ type="bool",
+ default=False,
+ help="If True, print the names of all the tensors.")
+ parser.add_argument(
"--printoptions",
nargs="*",
type=parse_numpy_printoption,
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index af9f11bb07..1f6016a91b 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -536,6 +536,7 @@ class _MonitoredSession(object):
will return True.
Example usage:
+
```python
with tf.Graph().as_default():
c = tf.placeholder(dtypes.float32)
@@ -552,6 +553,7 @@ class _MonitoredSession(object):
while not session.should_stop():
a = session.run_step_fn(step_fn)
```
+
Hooks interact with the `run_with_hooks()` call inside the `step_fn`
as they do with a `MonitoredSession.run` call.
diff --git a/tensorflow/python/util/tf_should_use.py b/tensorflow/python/util/tf_should_use.py
index a576547d5f..37733152e8 100644
--- a/tensorflow/python/util/tf_should_use.py
+++ b/tensorflow/python/util/tf_should_use.py
@@ -44,7 +44,7 @@ def _add_should_use_warning(x, fatal_error=False):
and is a very shallow wrapper for `x` which logs access into `x`.
"""
del fatal_error
- if x is None: # special corner case where x is None
+ if x is None or x == []: # pylint: disable=g-explicit-bool-comparison
return x
if context.in_eager_mode():
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 2094061b44..d78362d4fb 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -390,8 +390,8 @@ port::Status CudnnSupport::Init() {
<< DriverVersionStatusToString(result);
} else {
const auto& version = result.ValueOrDie();
- LOG(INFO) << "possibly insufficient driver version: "
- << DriverVersionToString(version);
+ LOG(ERROR) << "possibly insufficient driver version: "
+ << DriverVersionToString(version);
// OS X kernel driver does not report version accurately
#if !defined(__APPLE__)
if (std::get<0>(version) < 340) {
@@ -961,7 +961,8 @@ class CudnnDropoutDescriptor : public CudnnDescriptorCommon<void> {
if (!allocated.ok() ||
(state_memory = allocated.ValueOrDie()) == nullptr) {
string error_msg =
- port::StrCat("Fail to allocate Cudnn dropout state memory");
+ port::StrCat("Failed to allocate Cudnn dropout state memory of ",
+ state_sizes_in_bytes, " bytes.");
status_ = port::Status(port::error::UNKNOWN, error_msg);
LOG(ERROR) << error_msg;
return;
@@ -970,7 +971,10 @@ class CudnnDropoutDescriptor : public CudnnDescriptorCommon<void> {
status = wrap::cudnnSetDropoutDescriptor(parent_, handle_, cudnn_handle,
dropout, state_memory.opaque(),
state_memory.size(), seed);
- CUDNN_RETURN_IF_FAIL(status, "Failed to set dropout descriptor");
+ CUDNN_RETURN_IF_FAIL(
+ status, port::StrCat(
+ "Failed to set dropout descriptor with state memory size: ",
+ state_memory.size(), " bytes."));
}
~CudnnDropoutDescriptor() {
@@ -1475,7 +1479,8 @@ bool CreateRnnWorkspace(Stream* stream, CUDAExecutor* parent,
auto allocated =
workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes);
if (!allocated.ok() || (*workspace = allocated.ValueOrDie()) == nullptr) {
- LOG(ERROR) << "Failed to allocate RNN workspace";
+ LOG(ERROR) << port::StrCat("Failed to allocate RNN workspace of ",
+ workspace_size_in_bytes, " bytes.");
return false;
}
} else {
@@ -1552,7 +1557,8 @@ bool CudnnSupport::DoRnnForwardImpl(
stream, reserve_space_size_in_bytes);
if (!allocated.ok() ||
(reserve_space = allocated.ValueOrDie()) == nullptr) {
- LOG(ERROR) << "Fail to allocate RNN reserve space";
+ LOG(ERROR) << "Failed to allocate RNN reserve space of "
+ << reserve_space_size_in_bytes << " bytes.";
return false;
}
}
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 16c3386e15..a3ba363469 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -168,26 +168,30 @@ WIN_COPTS = [
# LINT.IfChange
def tf_copts():
- return (if_not_windows([
- "-DEIGEN_AVOID_STL_ARRAY",
- "-Iexternal/gemmlowp",
- "-Wno-sign-compare",
- "-ftemplate-depth=900",
- "-fno-exceptions",
- ]) + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1", "-fopenmp",]) + if_android_arm(
- ["-mfpu=neon"]) + if_linux_x86_64(["-msse3"]) + select({
- clean_dep("//tensorflow:android"): [
- "-std=c++11",
- "-DTF_LEAN_BINARY",
- "-O2",
- "-Wno-narrowing",
- "-fomit-frame-pointer",
- ],
- clean_dep("//tensorflow:darwin"): [],
- clean_dep("//tensorflow:windows"): WIN_COPTS,
- clean_dep("//tensorflow:windows_msvc"): WIN_COPTS,
- clean_dep("//tensorflow:ios"): ["-std=c++11"],
- "//conditions:default": ["-pthread"]
+ return (
+ if_not_windows([
+ "-DEIGEN_AVOID_STL_ARRAY",
+ "-Iexternal/gemmlowp",
+ "-Wno-sign-compare",
+ "-fno-exceptions",
+ "-ftemplate-depth=900"])
+ + if_cuda(["-DGOOGLE_CUDA=1"])
+ + if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML", "-fopenmp",])
+ + if_android_arm(["-mfpu=neon"])
+ + if_linux_x86_64(["-msse3"])
+ + select({
+ clean_dep("//tensorflow:android"): [
+ "-std=c++11",
+ "-DTF_LEAN_BINARY",
+ "-O2",
+ "-Wno-narrowing",
+ "-fomit-frame-pointer",
+ ],
+ clean_dep("//tensorflow:darwin"): [],
+ clean_dep("//tensorflow:windows"): WIN_COPTS,
+ clean_dep("//tensorflow:windows_msvc"): WIN_COPTS,
+ clean_dep("//tensorflow:ios"): ["-std=c++11"],
+ "//conditions:default": ["-pthread"]
}))
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt
new file mode 100644
index 0000000000..f5ed263f0e
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt
@@ -0,0 +1,54 @@
+path: "tensorflow.estimator.BaselineClassifier"
+tf_class {
+ is_instance: "<class \'tensorflow.python.estimator.canned.baseline.BaselineClassifier\'>"
+ is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "config"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "model_dir"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "model_fn"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "params"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Ftrl\', \'None\'], "
+ }
+ member_method {
+ name: "evaluate"
+ argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "export_savedmodel"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "get_variable_names"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_variable_value"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "latest_checkpoint"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "predict"
+ argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "train"
+ argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt
new file mode 100644
index 0000000000..61a29942c5
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt
@@ -0,0 +1,54 @@
+path: "tensorflow.estimator.BaselineRegressor"
+tf_class {
+ is_instance: "<class \'tensorflow.python.estimator.canned.baseline.BaselineRegressor\'>"
+ is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "config"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "model_dir"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "model_fn"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "params"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Ftrl\', \'None\'], "
+ }
+ member_method {
+ name: "evaluate"
+ argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "export_savedmodel"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "get_variable_names"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_variable_value"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "latest_checkpoint"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "predict"
+ argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "train"
+ argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt
index ef93a61bd8..cdc367b99e 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt
@@ -1,6 +1,14 @@
path: "tensorflow.estimator"
tf_module {
member {
+ name: "BaselineClassifier"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "BaselineRegressor"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "DNNClassifier"
mtype: "<type \'type\'>"
}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
new file mode 100644
index 0000000000..763184899c
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
@@ -0,0 +1,179 @@
+path: "tensorflow.keras.layers.GRUCell"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.GRUCell\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "activity_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "inbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "outbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "scope_name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\'], "
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_weight"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\', \'states\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "compute_mask"
+ argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "count_params"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_weights"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "set_weights"
+ argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
index 9237399254..889f2cbc23 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
@@ -1,15 +1,35 @@
path: "tensorflow.keras.layers.GRU"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.GRU\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.Recurrent\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.RNN\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
+ name: "activation"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "activity_regularizer"
mtype: "<type \'property\'>"
}
member {
+ name: "bias_constraint"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "bias_initializer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "bias_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dropout"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "dtype"
mtype: "<type \'property\'>"
}
@@ -18,6 +38,10 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "implementation"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -34,6 +58,18 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "kernel_constraint"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "kernel_initializer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "kernel_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "losses"
mtype: "<type \'property\'>"
}
@@ -66,10 +102,34 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "recurrent_activation"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "recurrent_constraint"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "recurrent_dropout"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "recurrent_initializer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "recurrent_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "scope_name"
mtype: "<type \'property\'>"
}
member {
+ name: "states"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -78,10 +138,18 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "units"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "updates"
mtype: "<type \'property\'>"
}
member {
+ name: "use_bias"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "variables"
mtype: "<type \'property\'>"
}
@@ -91,7 +159,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\'], "
+ argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\', \'False\', \'False\', \'False\', \'False\', \'False\'], "
}
member_method {
name: "add_loss"
@@ -138,10 +206,6 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_constants"
- argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
name: "get_initial_state"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
@@ -159,7 +223,7 @@ tf_class {
}
member_method {
name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "get_output_at"
@@ -182,10 +246,6 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "preprocess_input"
- argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
name: "reset_states"
argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -193,8 +253,4 @@ tf_class {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
- member_method {
- name: "step"
- argspec: "args=[\'self\', \'inputs\', \'states\'], varargs=None, keywords=None, defaults=None"
- }
}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
new file mode 100644
index 0000000000..4ce7c34f6c
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
@@ -0,0 +1,179 @@
+path: "tensorflow.keras.layers.LSTMCell"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.LSTMCell\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "activity_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "inbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "outbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "scope_name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\'], "
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_weight"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\', \'states\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "compute_mask"
+ argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "count_params"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_weights"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "set_weights"
+ argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt
index 20935e2f99..e1a1d0d58e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt
@@ -1,15 +1,35 @@
path: "tensorflow.keras.layers.LSTM"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.LSTM\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.Recurrent\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.RNN\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
+ name: "activation"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "activity_regularizer"
mtype: "<type \'property\'>"
}
member {
+ name: "bias_constraint"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "bias_initializer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "bias_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dropout"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "dtype"
mtype: "<type \'property\'>"
}
@@ -18,6 +38,10 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "implementation"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -34,6 +58,18 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "kernel_constraint"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "kernel_initializer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "kernel_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "losses"
mtype: "<type \'property\'>"
}
@@ -66,10 +102,34 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "recurrent_activation"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "recurrent_constraint"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "recurrent_dropout"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "recurrent_initializer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "recurrent_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "scope_name"
mtype: "<type \'property\'>"
}
member {
+ name: "states"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -78,10 +138,22 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "unit_forget_bias"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "units"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "updates"
mtype: "<type \'property\'>"
}
member {
+ name: "use_bias"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "variables"
mtype: "<type \'property\'>"
}
@@ -91,7 +163,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\'], "
+ argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\', \'False\', \'False\', \'False\', \'False\', \'False\'], "
}
member_method {
name: "add_loss"
@@ -138,10 +210,6 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_constants"
- argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
name: "get_initial_state"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
@@ -159,7 +227,7 @@ tf_class {
}
member_method {
name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "get_output_at"
@@ -182,10 +250,6 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "preprocess_input"
- argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
name: "reset_states"
argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -193,8 +257,4 @@ tf_class {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
- member_method {
- name: "step"
- argspec: "args=[\'self\', \'inputs\', \'states\'], varargs=None, keywords=None, defaults=None"
- }
}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt
new file mode 100644
index 0000000000..c7c9b10f22
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt
@@ -0,0 +1,191 @@
+path: "tensorflow.keras.layers.RNN"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.RNN\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "activity_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "inbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "outbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "scope_name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "states"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'cell\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\', \'activity_regularizer\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_weight"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\', \'initial_state\', \'constants\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "compute_mask"
+ argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "count_params"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "get_output_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_weights"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "set_weights"
+ argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
new file mode 100644
index 0000000000..10c7f8867c
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
@@ -0,0 +1,179 @@
+path: "tensorflow.keras.layers.SimpleRNNCell"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.SimpleRNNCell\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "activity_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "inbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "outbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "scope_name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\'], "
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_weight"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\', \'states\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "compute_mask"
+ argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "count_params"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_weights"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "set_weights"
+ argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt
index f4148fcc23..588df21088 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt
@@ -1,15 +1,35 @@
path: "tensorflow.keras.layers.SimpleRNN"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.SimpleRNN\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.Recurrent\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.RNN\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
+ name: "activation"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "activity_regularizer"
mtype: "<type \'property\'>"
}
member {
+ name: "bias_constraint"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "bias_initializer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "bias_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dropout"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "dtype"
mtype: "<type \'property\'>"
}
@@ -34,6 +54,18 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "kernel_constraint"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "kernel_initializer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "kernel_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "losses"
mtype: "<type \'property\'>"
}
@@ -66,10 +98,30 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "recurrent_constraint"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "recurrent_dropout"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "recurrent_initializer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "recurrent_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "scope_name"
mtype: "<type \'property\'>"
}
member {
+ name: "states"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -78,10 +130,18 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "units"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "updates"
mtype: "<type \'property\'>"
}
member {
+ name: "use_bias"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "variables"
mtype: "<type \'property\'>"
}
@@ -91,7 +151,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\'], "
+ argspec: "args=[\'self\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'False\', \'False\', \'False\', \'False\', \'False\'], "
}
member_method {
name: "add_loss"
@@ -138,10 +198,6 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_constants"
- argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
name: "get_initial_state"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
@@ -159,7 +215,7 @@ tf_class {
}
member_method {
name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "get_output_at"
@@ -182,10 +238,6 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "preprocess_input"
- argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
name: "reset_states"
argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -193,8 +245,4 @@ tf_class {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
- member_method {
- name: "step"
- argspec: "args=[\'self\', \'inputs\', \'states\'], varargs=None, keywords=None, defaults=None"
- }
}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
new file mode 100644
index 0000000000..5779e41342
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
@@ -0,0 +1,183 @@
+path: "tensorflow.keras.layers.StackedRNNCells"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.StackedRNNCells\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "activity_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "inbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "outbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "scope_name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "state_size"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'cells\'], varargs=None, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_weight"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\', \'states\'], varargs=None, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "compute_mask"
+ argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "count_params"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "get_output_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_weights"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "set_weights"
+ argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt
index 8466c3e039..fe336c4be5 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt
@@ -141,6 +141,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "GRUCell"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "GaussianDropout"
mtype: "<type \'type\'>"
}
@@ -209,6 +213,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "LSTMCell"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "Lambda"
mtype: "<type \'type\'>"
}
@@ -273,6 +281,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "RNN"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "RepeatVector"
mtype: "<type \'type\'>"
}
@@ -293,6 +305,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "SimpleRNNCell"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "SpatialDropout1D"
mtype: "<type \'type\'>"
}
@@ -305,6 +321,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "StackedRNNCells"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "ThresholdedReLU"
mtype: "<type \'type\'>"
}
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
index 1e9d28ca74..ebd9c079b5 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
@@ -289,6 +289,10 @@ tf_module {
argspec: "args=[\'_sentinel\', \'labels\', \'logits\', \'dim\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'-1\', \'None\'], "
}
member_method {
+ name: "softmax_cross_entropy_with_logits_v2"
+ argspec: "args=[\'_sentinel\', \'labels\', \'logits\', \'dim\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'-1\', \'None\'], "
+ }
+ member_method {
name: "softplus"
argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh
index 2b9aec6c31..db02f6ef10 100755
--- a/tensorflow/tools/ci_build/ci_parameterized_build.sh
+++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh
@@ -147,6 +147,38 @@ BAZEL_TARGET="//tensorflow/... -//tensorflow/compiler/..."
if [[ -n "$TF_SKIP_CONTRIB_TESTS" ]]; then
BAZEL_TARGET="$BAZEL_TARGET -//tensorflow/contrib/..."
+else
+ BAZEL_TARGET="${BAZEL_TARGET} -//tensorflow/contrib/lite/..."
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite:context_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite:framework"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite:interpreter_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite:model_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/toco:toco"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite:simple_memory_arena_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite:string_util_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:activations_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:add_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:basic_rnn_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:concatenation_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:conv_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:depthwise_conv_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:embedding_lookup_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:embedding_lookup_sparse_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:fully_connected_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/testing:generated_examples_zip_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:hashtable_lookup_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:local_response_norm_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:lsh_projection_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:lstm_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:l2norm_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:mul_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:pooling_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:reshape_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:resize_bilinear_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:skip_gram_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:softmax_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:space_to_depth_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:svdf_test"
fi
TUT_TEST_DATA_DIR="/tmp/tf_tutorial_test_data"
diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh
index f1c207f9b6..404a9a6b62 100755
--- a/tensorflow/tools/ci_build/ci_sanity.sh
+++ b/tensorflow/tools/ci_build/ci_sanity.sh
@@ -98,7 +98,8 @@ do_pylint() {
"^tensorflow/contrib/eager/python/evaluator\.py.*\[E0202.*method-hidden "\
"^tensorflow/contrib/eager/python/metrics_impl\.py.*\[E0202.*method-hidden "\
"^tensorflow/python/platform/gfile\.py.*\[E0301.*non-iterator "\
-"^tensorflow/python/keras/_impl/keras/callbacks\.py.*\[E1133.*not-an-iterable"
+"^tensorflow/python/keras/_impl/keras/callbacks\.py.*\[E1133.*not-an-iterable "\
+"^tensorflow/python/keras/_impl/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition"
echo "ERROR_WHITELIST=\"${ERROR_WHITELIST}\""
@@ -400,9 +401,14 @@ cmd_status(){
}
# Run bazel build --nobuild to test the validity of the BUILD files
+# TODO(mikecase): Remove TF Lite exclusion from this list. Exclusion is
+# necessary since the @androidsdk WORKSPACE dependency is commented
+# out by default in TF WORKSPACE file.
do_bazel_nobuild() {
BUILD_TARGET="//tensorflow/..."
- BUILD_CMD="bazel build --nobuild ${BAZEL_FLAGS} ${BUILD_TARGET}"
+ BUILD_TARGET="${BUILD_TARGET} -//tensorflow/contrib/lite/java/demo/app/src/main/..."
+ BUILD_TARGET="${BUILD_TARGET} -//tensorflow/contrib/lite/schema/..."
+ BUILD_CMD="bazel build --nobuild ${BAZEL_FLAGS} -- ${BUILD_TARGET}"
${BUILD_CMD}
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh
index 5de5a379ac..df6016504c 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh
@@ -33,4 +33,35 @@ yes "" | $PYTHON_BIN_PATH configure.py
bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test -k \
--jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \
--test_output=errors -- \
- //tensorflow/contrib/...
+ //tensorflow/contrib/... \
+ -//tensorflow/contrib/lite/... \
+ //tensorflow/contrib/lite:context_test \
+ //tensorflow/contrib/lite:framework \
+ //tensorflow/contrib/lite:interpreter_test \
+ //tensorflow/contrib/lite:model_test \
+ //tensorflow/contrib/lite/toco:toco \
+ //tensorflow/contrib/lite:simple_memory_arena_test \
+ //tensorflow/contrib/lite:string_util_test \
+ //tensorflow/contrib/lite/kernels:activations_test \
+ //tensorflow/contrib/lite/kernels:add_test \
+ //tensorflow/contrib/lite/kernels:basic_rnn_test \
+ //tensorflow/contrib/lite/kernels:concatenation_test \
+ //tensorflow/contrib/lite/kernels:conv_test \
+ //tensorflow/contrib/lite/kernels:depthwise_conv_test \
+ //tensorflow/contrib/lite/kernels:embedding_lookup_test \
+ //tensorflow/contrib/lite/kernels:embedding_lookup_sparse_test \
+ //tensorflow/contrib/lite/kernels:fully_connected_test \
+ //tensorflow/contrib/lite/testing:generated_examples_zip_test \
+ //tensorflow/contrib/lite/kernels:hashtable_lookup_test \
+ //tensorflow/contrib/lite/kernels:local_response_norm_test \
+ //tensorflow/contrib/lite/kernels:lsh_projection_test \
+ //tensorflow/contrib/lite/kernels:lstm_test \
+ //tensorflow/contrib/lite/kernels:l2norm_test \
+ //tensorflow/contrib/lite/kernels:mul_test \
+ //tensorflow/contrib/lite/kernels:pooling_test \
+ //tensorflow/contrib/lite/kernels:reshape_test \
+ //tensorflow/contrib/lite/kernels:resize_bilinear_test \
+ //tensorflow/contrib/lite/kernels:skip_gram_test \
+ //tensorflow/contrib/lite/kernels:softmax_test \
+ //tensorflow/contrib/lite/kernels:space_to_depth_test \
+ //tensorflow/contrib/lite/kernels:svdf_test
diff --git a/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh b/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh
index 8042522ef8..ddaaddc917 100755
--- a/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh
+++ b/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh
@@ -34,4 +34,4 @@ bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test,-nomac \
--test_timeout 300,450,1200,3600 \
--test_size_filters=small,medium \
--jobs=${N_JOBS} --build_tests_only --test_output=errors -k -- \
- //tensorflow/contrib/...
+ //tensorflow/contrib/... -//tensorflow/contrib/lite/...
diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel
index 20e1dcd085..1a0145b078 100644
--- a/tensorflow/tools/docker/Dockerfile.devel
+++ b/tensorflow/tools/docker/Dockerfile.devel
@@ -83,6 +83,11 @@ ENV CI_BUILD_PYTHON python
RUN tensorflow/tools/ci_build/builds/configured CPU \
bazel build -c opt --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" \
+ # For optimized builds appropriate for the hardware platform of your choosing, uncomment below...
+ # For ivy-bridge or sandy-bridge
+ # --copt=-march="ivybridge" \
+ # for haswell, broadwell, or skylake
+ # --copt=-march="haswell" \
tensorflow/tools/pip_package:build_pip_package && \
bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/pip && \
pip --no-cache-dir install --upgrade /tmp/pip/tensorflow-*.whl && \
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index c6e577223f..a3ab40ceef 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -158,6 +158,9 @@ sh_binary(
"//tensorflow/contrib/graph_editor:graph_editor_pip",
"//tensorflow/contrib/keras:keras",
"//tensorflow/contrib/labeled_tensor:labeled_tensor_pip",
+ "//tensorflow/contrib/lite/toco:toco",
+ "//tensorflow/contrib/lite/toco/python:toco_wrapper",
+ "//tensorflow/contrib/lite/toco/python:toco_from_protos",
"//tensorflow/contrib/ndlstm:ndlstm",
"//tensorflow/contrib/nn:nn_py",
"//tensorflow/contrib/predictor:predictor_pip",
diff --git a/tensorflow/tools/pip_package/MANIFEST.in b/tensorflow/tools/pip_package/MANIFEST.in
index ef6cf56421..86c5e4776d 100644
--- a/tensorflow/tools/pip_package/MANIFEST.in
+++ b/tensorflow/tools/pip_package/MANIFEST.in
@@ -4,6 +4,7 @@ recursive-include * *.so
recursive-include * *.dll
recursive-include * *.lib
recursive-include * *.csv
+recursive-include tensorflow/aux-bin *
recursive-include tensorflow/include/tensorflow *.h
recursive-include tensorflow/include/Eigen *
recursive-include tensorflow/include/external *
diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh
index cbf06a97d0..8249703ba7 100755
--- a/tensorflow/tools/pip_package/build_pip_package.sh
+++ b/tensorflow/tools/pip_package/build_pip_package.sh
@@ -137,6 +137,9 @@ function main() {
fi
fi
fi
+ # Install toco as a binary in aux-bin.
+ mkdir "${TMPDIR}/tensorflow/aux-bin"
+ cp bazel-bin/tensorflow/contrib/lite/toco/toco ${TMPDIR}/tensorflow/aux-bin/
fi
# protobuf pip package doesn't ship with header files. Copy the headers
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index 0c54300e06..a493c6f2aa 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -69,6 +69,8 @@ if sys.version_info < (3, 4):
# pylint: disable=line-too-long
CONSOLE_SCRIPTS = [
'freeze_graph = tensorflow.python.tools.freeze_graph:main',
+ 'toco_from_protos = tensorflow.contrib.lite.toco.python.toco_from_protos:main',
+ 'toco = tensorflow.contrib.lite.toco.python.toco_wrapper:main',
'saved_model_cli = tensorflow.python.tools.saved_model_cli:main',
# We need to keep the TensorBoard command, even though the console script
# is now declared by the tensorboard pip package. If we remove the
@@ -188,7 +190,6 @@ headers = (list(find_files('*.h', 'tensorflow/core')) +
list(find_files('*', 'external/eigen_archive')) +
list(find_files('*.h', 'external/nsync/public')))
-
setup(
name=project_name,
version=_VERSION.replace('-', ''),
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index afcae6eade..2c9f067882 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -1,21 +1,24 @@
# TensorFlow external dependencies that can be loaded in WORKSPACE files.
load("//third_party/gpus:cuda_configure.bzl", "cuda_configure")
+
load("//third_party/sycl:sycl_configure.bzl", "sycl_configure")
load("//third_party/mkl:build_defs.bzl", "mkl_repository")
-load("@io_bazel_rules_closure//closure/private:java_import_external.bzl",
- "java_import_external")
+load(
+ "@io_bazel_rules_closure//closure/private:java_import_external.bzl",
+ "java_import_external",
+)
load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external")
load("//third_party/py:python_configure.bzl", "python_configure")
-load("//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl",
- "arm_compiler_configure")
-
+load(
+ "//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl",
+ "arm_compiler_configure",
+)
def _is_windows(repository_ctx):
"""Returns true if the host operating system is windows."""
return repository_ctx.os.name.lower().find("windows") != -1
-
def _get_env_var(repository_ctx, name):
"""Find an environment variable."""
if name in repository_ctx.os.environ:
@@ -23,7 +26,6 @@ def _get_env_var(repository_ctx, name):
else:
return None
-
# Parse the bazel version string from `native.bazel_version`.
def _parse_bazel_version(bazel_version):
# Remove commit from version.
@@ -39,7 +41,6 @@ def _parse_bazel_version(bazel_version):
version_tuple += (str(number),)
return version_tuple
-
# Check that a specific bazel version is being used.
def check_version(bazel_version):
if "bazel_version" not in dir(native):
@@ -56,11 +57,9 @@ def check_version(bazel_version):
fail("\nCurrent Bazel version is {}, expected at least {}\n".format(
native.bazel_version, bazel_version))
-
def _repos_are_siblings():
return Label("@foo//bar").workspace_root.startswith("../")
-
# Temporary workaround to support including TensorFlow as a submodule until this
# use-case is supported in the next Bazel release.
def _temp_workaround_http_archive_impl(repo_ctx):
@@ -73,9 +72,7 @@ def _temp_workaround_http_archive_impl(repo_ctx):
if repo_ctx.attr.patch_file != None:
_apply_patch(repo_ctx, repo_ctx.attr.patch_file)
-
temp_workaround_http_archive = repository_rule(
- implementation = _temp_workaround_http_archive_impl,
attrs = {
"build_file": attr.label(),
"repository": attr.string(),
@@ -84,6 +81,7 @@ temp_workaround_http_archive = repository_rule(
"sha256": attr.string(default = ""),
"strip_prefix": attr.string(default = ""),
},
+ implementation = _temp_workaround_http_archive_impl,
)
# Executes specified command with arguments and calls 'fail' if it exited with
@@ -95,7 +93,6 @@ def _execute_and_check_ret_code(repo_ctx, cmd_and_args):
+ "Stderr: {3}").format(" ".join(cmd_and_args), result.return_code,
result.stdout, result.stderr))
-
# Apply a patch_file to the repository root directory
# Runs 'patch -p1'
def _apply_patch(repo_ctx, patch_file):
@@ -113,7 +110,6 @@ def _apply_patch(repo_ctx, patch_file):
cmd = [bazel_sh, "-c", " ".join(cmd)]
_execute_and_check_ret_code(repo_ctx, cmd)
-
# Download the repository and apply a patch to its root
def _patched_http_archive_impl(repo_ctx):
repo_ctx.download_and_extract(
@@ -122,9 +118,7 @@ def _patched_http_archive_impl(repo_ctx):
stripPrefix=repo_ctx.attr.strip_prefix)
_apply_patch(repo_ctx, repo_ctx.attr.patch_file)
-
patched_http_archive = repository_rule(
- implementation = _patched_http_archive_impl,
attrs = {
"patch_file": attr.label(),
"build_file": attr.label(),
@@ -133,9 +127,9 @@ patched_http_archive = repository_rule(
"sha256": attr.string(default = ""),
"strip_prefix": attr.string(default = ""),
},
+ implementation = _patched_http_archive_impl,
)
-
# If TensorFlow is linked as a submodule.
# path_prefix is no longer used.
# tf_repo_name is thought to be under consideration.
@@ -448,11 +442,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
native.http_archive(
name = "nsync",
urls = [
- "https://mirror.bazel.build/github.com/google/nsync/archive/4fc8ff3e7626c5f24bc9674438d8257f0ffc226c.tar.gz",
- # "https://github.com/google/nsync/archive/4fc8ff3e7626c5f24bc9674438d8257f0ffc226c.tar.gz",
+ "https://mirror.bazel.build/github.com/google/nsync/archive/93815892dddafe9146a5f7e7042281d59d0f4323.tar.gz",
+ # "https://github.com/google/nsync/archive/93815892dddafe9146a5f7e7042281d59d0f4323.tar.gz",
],
- sha256 = "ffbbe828f3d0bef75462e34801de5cea31d10aa63eaa42a4ed74c46521bdfd58",
- strip_prefix = "nsync-4fc8ff3e7626c5f24bc9674438d8257f0ffc226c",
+ sha256 = "e3bd4555415ace511338fc27e595351738eea4e9006f1612b76c82914770716b",
+ strip_prefix = "nsync-93815892dddafe9146a5f7e7042281d59d0f4323",
)
native.http_archive(
@@ -821,3 +815,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
"https://github.com/google/flatbuffers/archive/971a68110e4fc1bace10fcb6deeb189e7e1a34ce.tar.gz",
],
)
+
+ native.new_http_archive(
+ name = "tflite_mobilenet",
+ build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
+ sha256 = "23f814d1c076bdf03715dfb6cab3713aa4fbdf040fd5448c43196bd2e97a4c1b",
+ urls = [
+ "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip"
+ ],
+ )
diff --git a/third_party/flatbuffers/flatbuffers.BUILD b/third_party/flatbuffers/flatbuffers.BUILD
index a426db0c50..e1563103c8 100644
--- a/third_party/flatbuffers/flatbuffers.BUILD
+++ b/third_party/flatbuffers/flatbuffers.BUILD
@@ -104,6 +104,10 @@ cc_binary(
"grpc/",
"include/",
],
+ linkopts = [
+ "-lm",
+ "-ldl",
+ ],
deps = [
":flatc_library",
],
diff --git a/third_party/tflite_mobilenet.BUILD b/third_party/tflite_mobilenet.BUILD
new file mode 100644
index 0000000000..75663eff48
--- /dev/null
+++ b/third_party/tflite_mobilenet.BUILD
@@ -0,0 +1,13 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+filegroup(
+ name = "model_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "BUILD",
+ ],
+ ),
+)
diff --git a/tools/bazel.rc b/tools/bazel.rc
index f609efe188..2447acd70a 100644
--- a/tools/bazel.rc
+++ b/tools/bazel.rc
@@ -24,13 +24,5 @@ build --define=use_fast_cpp_protos=true
build --define=allow_oversize_protos=true
build --spawn_strategy=standalone
-test --spawn_strategy=standalone
-run --spawn_strategy=standalone
-
build --genrule_strategy=standalone
-test --genrule_strategy=standalone
-run --genrule_strategy=standalone
-
build -c opt
-test -c opt
-run -c opt